diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5aa170b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,30 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build-and-test: + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-24.04] + compiler: [g++, clang++] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + + - name: Build and test + env: + CXX: ${{ matrix.compiler }} + run: make -f Makefile.new clean && make -f Makefile.new all + + macos: + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + + - name: Build and test + run: make -f Makefile.new clean && make -f Makefile.new all diff --git a/.gitignore b/.gitignore index 1c019c9..002a119 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,7 @@ src/*_parser/*_lexer.yy.c src/*_parser/*_parser.output src/*_parser/*_parser.report +# New parser build artifacts +libsqlparser.a +run_tests +run_bench diff --git a/Makefile.new b/Makefile.new new file mode 100644 index 0000000..2b31175 --- /dev/null +++ b/Makefile.new @@ -0,0 +1,93 @@ +CXX = g++ +CXXFLAGS = -std=c++17 -Wall -Wextra -g -O2 +CPPFLAGS = -I./include -I./third_party/googletest/googletest/include + +PROJECT_ROOT = . 
+SRC_DIR = $(PROJECT_ROOT)/src/sql_parser +INCLUDE_DIR = $(PROJECT_ROOT)/include/sql_parser +TEST_DIR = $(PROJECT_ROOT)/tests + +# Library sources +LIB_SRCS = $(SRC_DIR)/arena.cpp $(SRC_DIR)/parser.cpp +LIB_OBJS = $(LIB_SRCS:.cpp=.o) +LIB_TARGET = $(PROJECT_ROOT)/libsqlparser.a + +# Google Test library +GTEST_DIR = $(PROJECT_ROOT)/third_party/googletest/googletest +GTEST_SRC = $(GTEST_DIR)/src/gtest-all.cc +GTEST_OBJ = $(GTEST_DIR)/src/gtest-all.o +GTEST_CPPFLAGS = -I$(GTEST_DIR)/include -I$(GTEST_DIR) + +# Test sources +TEST_SRCS = $(TEST_DIR)/test_main.cpp \ + $(TEST_DIR)/test_arena.cpp \ + $(TEST_DIR)/test_tokenizer.cpp \ + $(TEST_DIR)/test_classifier.cpp \ + $(TEST_DIR)/test_expression.cpp \ + $(TEST_DIR)/test_set.cpp \ + $(TEST_DIR)/test_select.cpp \ + $(TEST_DIR)/test_emitter.cpp \ + $(TEST_DIR)/test_stmt_cache.cpp \ + $(TEST_DIR)/test_insert.cpp \ + $(TEST_DIR)/test_update.cpp \ + $(TEST_DIR)/test_delete.cpp \ + $(TEST_DIR)/test_compound.cpp \ + $(TEST_DIR)/test_digest.cpp +TEST_OBJS = $(TEST_SRCS:.cpp=.o) +TEST_TARGET = $(PROJECT_ROOT)/run_tests + +# Google Benchmark +GBENCH_DIR = $(PROJECT_ROOT)/third_party/benchmark +GBENCH_SRCS = $(filter-out $(GBENCH_DIR)/src/benchmark_main.cc, $(wildcard $(GBENCH_DIR)/src/*.cc)) +GBENCH_OBJS = $(GBENCH_SRCS:.cc=.o) +GBENCH_CPPFLAGS = -I$(GBENCH_DIR)/include -I$(GBENCH_DIR)/src -DHAVE_STD_REGEX -DHAVE_STEADY_CLOCK + +BENCH_DIR = $(PROJECT_ROOT)/bench +BENCH_SRCS = $(BENCH_DIR)/bench_main.cpp $(BENCH_DIR)/bench_parser.cpp +BENCH_OBJS = $(BENCH_SRCS:.cpp=.o) +BENCH_TARGET = $(PROJECT_ROOT)/run_bench + +.PHONY: all lib test bench clean + +all: lib test + +lib: $(LIB_TARGET) + +$(LIB_TARGET): $(LIB_OBJS) + ar rcs $@ $^ + @echo "Built $@" + +$(SRC_DIR)/%.o: $(SRC_DIR)/%.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) -c $< -o $@ + +# Google Test object +$(GTEST_OBJ): $(GTEST_SRC) + $(CXX) $(CXXFLAGS) $(GTEST_CPPFLAGS) -c $< -o $@ + +# Test objects +$(TEST_DIR)/%.o: $(TEST_DIR)/%.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(GTEST_CPPFLAGS) -c 
$< -o $@ + +test: $(TEST_TARGET) + ./$(TEST_TARGET) + +$(TEST_TARGET): $(TEST_OBJS) $(GTEST_OBJ) $(LIB_TARGET) + $(CXX) $(CXXFLAGS) -o $@ $(TEST_OBJS) $(GTEST_OBJ) -L$(PROJECT_ROOT) -lsqlparser -lpthread + +# Benchmark objects +$(GBENCH_DIR)/src/%.o: $(GBENCH_DIR)/src/%.cc + $(CXX) $(CXXFLAGS) $(GBENCH_CPPFLAGS) -c $< -o $@ + +$(BENCH_DIR)/%.o: $(BENCH_DIR)/%.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(GBENCH_CPPFLAGS) -c $< -o $@ + +bench: $(BENCH_TARGET) + ./$(BENCH_TARGET) --benchmark_format=console + +$(BENCH_TARGET): $(BENCH_OBJS) $(GBENCH_OBJS) $(LIB_TARGET) + $(CXX) $(CXXFLAGS) -o $@ $(BENCH_OBJS) $(GBENCH_OBJS) -L$(PROJECT_ROOT) -lsqlparser -lpthread + +clean: + rm -f $(LIB_OBJS) $(LIB_TARGET) $(TEST_OBJS) $(GTEST_OBJ) $(TEST_TARGET) + rm -f $(BENCH_OBJS) $(GBENCH_OBJS) $(BENCH_TARGET) + @echo "Cleaned." diff --git a/bench/bench_main.cpp b/bench/bench_main.cpp new file mode 100644 index 0000000..71fefa0 --- /dev/null +++ b/bench/bench_main.cpp @@ -0,0 +1,3 @@ +#include + +BENCHMARK_MAIN(); diff --git a/bench/bench_parser.cpp b/bench/bench_parser.cpp new file mode 100644 index 0000000..8ac22a1 --- /dev/null +++ b/bench/bench_parser.cpp @@ -0,0 +1,239 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +// ========== Tier 2: Classification ========== +// Target: <100ns + +static void BM_Classify_Insert(benchmark::State& state) { + Parser parser; + const char* sql = "INSERT INTO users VALUES (1, 'name', 'email')"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Insert); + +static void BM_Classify_Update(benchmark::State& state) { + Parser parser; + const char* sql = "UPDATE users SET name = 'x' WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Update); + +static void 
BM_Classify_Delete(benchmark::State& state) { + Parser parser; + const char* sql = "DELETE FROM users WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Delete); + +static void BM_Classify_Begin(benchmark::State& state) { + Parser parser; + const char* sql = "BEGIN"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Begin); + +// ========== Tier 1: SET parse ========== +// Target: <300ns + +static void BM_Set_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SET @@session.wait_timeout = 600"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_Simple); + +static void BM_Set_Names(benchmark::State& state) { + Parser parser; + const char* sql = "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_Names); + +static void BM_Set_MultiVar(benchmark::State& state) { + Parser parser; + const char* sql = "SET autocommit = 1, wait_timeout = 28800, sql_mode = 'STRICT_TRANS_TABLES'"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_MultiVar); + +static void BM_Set_FunctionRHS(benchmark::State& state) { + Parser parser; + const char* sql = "SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_FunctionRHS); + +// ========== Tier 1: SELECT parse ========== +// Target: <500ns simple, <2us complex + +static void BM_Select_Simple(benchmark::State& 
state) { + Parser parser; + const char* sql = "SELECT col FROM t WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_Simple); + +static void BM_Select_MultiColumn(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT id, name, email, status FROM users WHERE active = 1 ORDER BY name LIMIT 100"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_MultiColumn); + +static void BM_Select_Join(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT u.id, o.total FROM users u JOIN orders o ON u.id = o.user_id WHERE o.status = 'active'"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_Join); + +static void BM_Select_Complex(benchmark::State& state) { + Parser parser; + const char* sql = + "SELECT u.id, u.name, COUNT(o.id) AS order_count " + "FROM users u " + "LEFT JOIN orders o ON u.id = o.user_id " + "WHERE u.status = 'active' AND u.created_at > '2024-01-01' " + "GROUP BY u.id, u.name " + "HAVING COUNT(o.id) > 5 " + "ORDER BY order_count DESC " + "LIMIT 50 OFFSET 10"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_Complex); + +static void BM_Select_MultiJoin(benchmark::State& state) { + Parser parser; + const char* sql = + "SELECT a.id, b.name, c.value, d.total " + "FROM t1 a " + "JOIN t2 b ON a.id = b.a_id " + "LEFT JOIN t3 c ON b.id = c.b_id " + "JOIN t4 d ON c.id = d.c_id " + "WHERE a.status = 1 AND d.total > 100 " + "ORDER BY d.total DESC " + "LIMIT 20"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_MultiJoin); + +// 
========== Query Reconstruction (round-trip) ========== +// Target: <500ns + +static void BM_Emit_SetSimple(benchmark::State& state) { + Parser parser; + const char* sql = "SET autocommit = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + benchmark::DoNotOptimize(emitter.result()); + } +} +BENCHMARK(BM_Emit_SetSimple); + +static void BM_Emit_SelectSimple(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT * FROM users WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + benchmark::DoNotOptimize(emitter.result()); + } +} +BENCHMARK(BM_Emit_SelectSimple); + +// ========== Arena reset ========== +// Target: <10ns + +static void BM_ArenaReset(benchmark::State& state) { + Arena arena(65536); + for (auto _ : state) { + arena.allocate(256); // allocate something + arena.reset(); + benchmark::DoNotOptimize(arena.bytes_used()); + } +} +BENCHMARK(BM_ArenaReset); + +// ========== PostgreSQL ========== + +static void BM_PgSQL_Select_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT col FROM t WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_PgSQL_Select_Simple); + +static void BM_PgSQL_Set_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SET work_mem = '256MB'"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_PgSQL_Set_Simple); diff --git a/docs/superpowers/plans/2026-03-24-benchmarks.md b/docs/superpowers/plans/2026-03-24-benchmarks.md new file mode 100644 index 0000000..a8102e9 --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-benchmarks.md @@ -0,0 +1,372 @@ +# Performance Benchmarks 
Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add Google Benchmark-based performance tests to validate the parser meets its latency targets and catch performance regressions. + +**Architecture:** Benchmarks use Google Benchmark (vendored alongside Google Test). Each benchmark measures a specific operation in isolation: Tier 2 classification, Tier 1 SET parse, Tier 1 SELECT parse (simple and complex), query reconstruction (round-trip), and arena reset. A single parser instance is reused across iterations (matching ProxySQL's per-thread usage pattern). + +**Tech Stack:** C++17, Google Benchmark + +**Spec:** `docs/superpowers/specs/2026-03-24-sql-parser-design.md` (Performance Targets section) + +--- + +## Scope + +1. Vendor Google Benchmark +2. Benchmark targets matching the spec: + - Tier 2 classification: <100ns + - Tier 1 SET parse: <300ns + - Tier 1 SELECT parse (simple): <500ns + - Tier 1 SELECT parse (complex): <2us + - Query reconstruction: <500ns + - Arena reset: <10ns +3. Makefile.new `bench` target + +**Not in scope:** CI integration for benchmarks (too noisy in CI), optimization work. 
+ +--- + +## File Structure + +``` +bench/ + bench_main.cpp — Google Benchmark main + bench_parser.cpp — All parser benchmarks + +Makefile.new — (modify) Add bench target +``` + +--- + +### Task 1: Benchmark Setup and All Benchmarks + +**Files:** +- Create: `bench/bench_main.cpp` +- Create: `bench/bench_parser.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Vendor Google Benchmark** + +```bash +git clone --depth 1 --branch v1.9.1 https://github.com/google/benchmark.git third_party/benchmark +``` + +- [ ] **Step 2: Create bench_main.cpp** + +Create `bench/bench_main.cpp`: +```cpp +#include + +BENCHMARK_MAIN(); +``` + +- [ ] **Step 3: Create bench_parser.cpp** + +Create `bench/bench_parser.cpp`: +```cpp +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +// ========== Tier 2: Classification ========== +// Target: <100ns + +static void BM_Classify_Insert(benchmark::State& state) { + Parser parser; + const char* sql = "INSERT INTO users VALUES (1, 'name', 'email')"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Insert); + +static void BM_Classify_Update(benchmark::State& state) { + Parser parser; + const char* sql = "UPDATE users SET name = 'x' WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Update); + +static void BM_Classify_Delete(benchmark::State& state) { + Parser parser; + const char* sql = "DELETE FROM users WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Delete); + +static void BM_Classify_Begin(benchmark::State& state) { + Parser parser; + const char* sql = "BEGIN"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = 
parser.parse(sql, len); + benchmark::DoNotOptimize(r.stmt_type); + } +} +BENCHMARK(BM_Classify_Begin); + +// ========== Tier 1: SET parse ========== +// Target: <300ns + +static void BM_Set_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SET @@session.wait_timeout = 600"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_Simple); + +static void BM_Set_Names(benchmark::State& state) { + Parser parser; + const char* sql = "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_Names); + +static void BM_Set_MultiVar(benchmark::State& state) { + Parser parser; + const char* sql = "SET autocommit = 1, wait_timeout = 28800, sql_mode = 'STRICT_TRANS_TABLES'"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_MultiVar); + +static void BM_Set_FunctionRHS(benchmark::State& state) { + Parser parser; + const char* sql = "SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Set_FunctionRHS); + +// ========== Tier 1: SELECT parse ========== +// Target: <500ns simple, <2us complex + +static void BM_Select_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT col FROM t WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_Simple); + +static void BM_Select_MultiColumn(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT id, name, email, status FROM users WHERE active = 1 ORDER BY name LIMIT 100"; + size_t len = strlen(sql); + 
for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_MultiColumn); + +static void BM_Select_Join(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT u.id, o.total FROM users u JOIN orders o ON u.id = o.user_id WHERE o.status = 'active'"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_Join); + +static void BM_Select_Complex(benchmark::State& state) { + Parser parser; + const char* sql = + "SELECT u.id, u.name, COUNT(o.id) AS order_count " + "FROM users u " + "LEFT JOIN orders o ON u.id = o.user_id " + "WHERE u.status = 'active' AND u.created_at > '2024-01-01' " + "GROUP BY u.id, u.name " + "HAVING COUNT(o.id) > 5 " + "ORDER BY order_count DESC " + "LIMIT 50 OFFSET 10"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_Complex); + +static void BM_Select_MultiJoin(benchmark::State& state) { + Parser parser; + const char* sql = + "SELECT a.id, b.name, c.value, d.total " + "FROM t1 a " + "JOIN t2 b ON a.id = b.a_id " + "LEFT JOIN t3 c ON b.id = c.b_id " + "JOIN t4 d ON c.id = d.c_id " + "WHERE a.status = 1 AND d.total > 100 " + "ORDER BY d.total DESC " + "LIMIT 20"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_Select_MultiJoin); + +// ========== Query Reconstruction (round-trip) ========== +// Target: <500ns + +static void BM_Emit_SetSimple(benchmark::State& state) { + Parser parser; + const char* sql = "SET autocommit = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + benchmark::DoNotOptimize(emitter.result()); + } +} +BENCHMARK(BM_Emit_SetSimple); + +static void 
BM_Emit_SelectSimple(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT * FROM users WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + benchmark::DoNotOptimize(emitter.result()); + } +} +BENCHMARK(BM_Emit_SelectSimple); + +// ========== Arena reset ========== +// Target: <10ns + +static void BM_ArenaReset(benchmark::State& state) { + Arena arena(65536); + for (auto _ : state) { + arena.allocate(256); // allocate something + arena.reset(); + benchmark::DoNotOptimize(arena.bytes_used()); + } +} +BENCHMARK(BM_ArenaReset); + +// ========== PostgreSQL ========== + +static void BM_PgSQL_Select_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SELECT col FROM t WHERE id = 1"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_PgSQL_Select_Simple); + +static void BM_PgSQL_Set_Simple(benchmark::State& state) { + Parser parser; + const char* sql = "SET work_mem = '256MB'"; + size_t len = strlen(sql); + for (auto _ : state) { + auto r = parser.parse(sql, len); + benchmark::DoNotOptimize(r.ast); + } +} +BENCHMARK(BM_PgSQL_Set_Simple); +``` + +- [ ] **Step 4: Update Makefile.new** + +Add benchmark build rules to `Makefile.new`: + +```makefile +# Google Benchmark +GBENCH_DIR = $(PROJECT_ROOT)/third_party/benchmark +### NOTE: After cloning Google Benchmark v1.9.1, verify the actual source files: +### ls third_party/benchmark/src/*.cc +### Then set GBENCH_SRCS to match. 
The following is a common set: +GBENCH_SRCS = $(wildcard $(GBENCH_DIR)/src/*.cc) +GBENCH_OBJS = $(GBENCH_SRCS:.cc=.o) +GBENCH_CPPFLAGS = -I$(GBENCH_DIR)/include -I$(GBENCH_DIR)/src -DHAVE_STD_REGEX -DHAVE_STEADY_CLOCK + +BENCH_DIR = $(PROJECT_ROOT)/bench +BENCH_SRCS = $(BENCH_DIR)/bench_main.cpp $(BENCH_DIR)/bench_parser.cpp +BENCH_OBJS = $(BENCH_SRCS:.cpp=.o) +BENCH_TARGET = $(PROJECT_ROOT)/run_bench + +# Benchmark objects +$(GBENCH_DIR)/src/%.o: $(GBENCH_DIR)/src/%.cc + $(CXX) $(CXXFLAGS) $(GBENCH_CPPFLAGS) -c $< -o $@ + +$(BENCH_DIR)/%.o: $(BENCH_DIR)/%.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(GBENCH_CPPFLAGS) -c $< -o $@ + +bench: $(BENCH_TARGET) + ./$(BENCH_TARGET) --benchmark_format=console + +$(BENCH_TARGET): $(BENCH_OBJS) $(GBENCH_OBJS) $(LIB_TARGET) + $(CXX) $(CXXFLAGS) -o $@ $(BENCH_OBJS) $(GBENCH_OBJS) -L$(PROJECT_ROOT) -lsqlparser -lpthread +``` + +Add `bench` to `.PHONY` and update `clean` to include bench artifacts: +```makefile +.PHONY: all lib test bench clean +``` + +Add to clean: +```makefile + rm -f $(BENCH_OBJS) $(GBENCH_OBJS) $(BENCH_TARGET) +``` + +- [ ] **Step 5: Create bench directory and build** + +```bash +mkdir -p bench +make -f Makefile.new clean && make -f Makefile.new lib && make -f Makefile.new test && make -f Makefile.new bench +``` + +- [ ] **Step 6: Update .gitignore** + +Add to `.gitignore`: +``` +run_bench +``` + +- [ ] **Step 7: Commit** + +```bash +git add bench/ third_party/benchmark Makefile.new .gitignore +git commit -m "feat: add Google Benchmark performance tests for parser operations" +``` diff --git a/docs/superpowers/plans/2026-03-24-expression-and-set-parser.md b/docs/superpowers/plans/2026-03-24-expression-and-set-parser.md new file mode 100644 index 0000000..76706db --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-expression-and-set-parser.md @@ -0,0 +1,1416 @@ +# Expression Parser + SET Deep Parser Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development 
(recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build the Pratt expression parser and full SET statement deep parser, upgrading SET from a Tier 2 stub to a Tier 1 parser that produces a complete AST for query reconstruction. + +**Architecture:** The expression parser is a standalone module using precedence climbing (Pratt parsing). It handles literals, identifiers, unary/binary operators, function calls, IS [NOT] NULL, BETWEEN, IN, CASE/WHEN, and subqueries. The SET parser uses the expression parser for right-hand-side values. Both are dialect-templated. After this plan, `parse_set()` returns `ParseResult::OK` with a full AST instead of `PARTIAL`. + +**Tech Stack:** C++17, existing arena/tokenizer/ast infrastructure from Plan 1 + +**Spec:** `docs/superpowers/specs/2026-03-24-sql-parser-design.md` + +--- + +## Scope + +This plan builds: +1. Pratt expression parser (`expression_parser.h`) — shared by SET and future SELECT parser +2. SET deep parser (`set_parser.h`) — full AST for all SET variants (MySQL + PostgreSQL) +3. Integration into `Parser` — `parse_set()` upgraded from stub to real parser +4. Tests for both expression parsing and SET parsing + +**Not in scope:** SELECT deep parser, emitter/reconstruction, prepared statement cache. 
+ +--- + +## File Structure + +``` +include/sql_parser/ + expression_parser.h — Pratt expression parser (header-only template) + set_parser.h — SET statement parser (header-only template) + common.h — (modify) Add NODE_SCOPE_SPECIFIER if needed + +src/sql_parser/ + parser.cpp — (modify) Replace parse_set() stub with real implementation + +tests/ + test_expression.cpp — Expression parser unit tests + test_set.cpp — SET parser unit tests + +Makefile.new — (modify) Add new test files +``` + +--- + +### Task 1: Expression Parser — Literals and Identifiers + +**Files:** +- Create: `include/sql_parser/expression_parser.h` +- Create: `tests/test_expression.cpp` +- Modify: `Makefile.new` — add `test_expression.cpp` to TEST_SRCS + +- [ ] **Step 1: Write failing tests for literal and identifier parsing** + +Create `tests/test_expression.cpp`: +```cpp +#include +#include "sql_parser/parser.h" +#include "sql_parser/expression_parser.h" + +using namespace sql_parser; + +// Helper: parse an expression from a SQL string using a fresh parser context. +// We use the tokenizer directly since expression parsing is an internal function. 
+class ExpressionTest : public ::testing::Test { +protected: + Arena arena{4096}; + Tokenizer tok; + + AstNode* parse_expr(const char* sql) { + tok.reset(sql, strlen(sql)); + ExpressionParser ep(tok, arena); + return ep.parse(); + } +}; + +TEST_F(ExpressionTest, IntegerLiteral) { + AstNode* node = parse_expr("42"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "42"); +} + +TEST_F(ExpressionTest, FloatLiteral) { + AstNode* node = parse_expr("3.14"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_FLOAT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "3.14"); +} + +TEST_F(ExpressionTest, StringLiteral) { + AstNode* node = parse_expr("'hello'"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_STRING); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "hello"); +} + +TEST_F(ExpressionTest, NullLiteral) { + AstNode* node = parse_expr("NULL"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_NULL); +} + +TEST_F(ExpressionTest, TrueLiteral) { + AstNode* node = parse_expr("TRUE"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "TRUE"); +} + +TEST_F(ExpressionTest, FalseLiteral) { + AstNode* node = parse_expr("FALSE"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "FALSE"); +} + +TEST_F(ExpressionTest, SimpleIdentifier) { + AstNode* node = parse_expr("my_column"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_COLUMN_REF); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "my_column"); +} + +TEST_F(ExpressionTest, QualifiedIdentifier) { + AstNode* node = parse_expr("t.col"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_QUALIFIED_NAME); + // first 
child = table, second child = column + ASSERT_NE(node->first_child, nullptr); + ASSERT_NE(node->first_child->next_sibling, nullptr); +} + +TEST_F(ExpressionTest, Asterisk) { + AstNode* node = parse_expr("*"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_ASTERISK); +} + +TEST_F(ExpressionTest, Placeholder) { + AstNode* node = parse_expr("?"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_PLACEHOLDER); +} + +TEST_F(ExpressionTest, DefaultKeyword) { + AstNode* node = parse_expr("DEFAULT"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IDENTIFIER); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "DEFAULT"); +} + +TEST_F(ExpressionTest, UserVariable) { + AstNode* node = parse_expr("@my_var"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_COLUMN_REF); +} + +TEST_F(ExpressionTest, ParenthesizedExpression) { + AstNode* node = parse_expr("(42)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "42"); +} +``` + +- [ ] **Step 2: Create expression_parser.h with atom parsing** + +Create `include/sql_parser/expression_parser.h`: +```cpp +#ifndef SQL_PARSER_EXPRESSION_PARSER_H +#define SQL_PARSER_EXPRESSION_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" + +namespace sql_parser { + +// Operator precedence levels for Pratt parsing +enum class Precedence : uint8_t { + NONE = 0, + OR, // OR + AND, // AND + NOT, // NOT (prefix) + COMPARISON, // =, <, >, <=, >=, !=, <>, IS, LIKE, IN, BETWEEN + ADDITION, // +, - + MULTIPLICATION,// *, /, % + UNARY, // - (prefix), NOT + POSTFIX, // IS NULL, IS NOT NULL + CALL, // function() + PRIMARY, // literals, identifiers +}; + +template +class ExpressionParser { +public: + ExpressionParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), 
arena_(arena) {} + + // Parse an expression with minimum precedence 0 + AstNode* parse(Precedence min_prec = Precedence::NONE) { + AstNode* left = parse_atom(); + if (!left) return nullptr; + + while (true) { + Precedence prec = infix_precedence(tok_.peek().type); + if (prec <= min_prec) break; + + left = parse_infix(left, prec); + if (!left) return nullptr; + } + + return left; + } + +private: + Tokenizer& tok_; + Arena& arena_; + + // Parse a primary expression (atom) + AstNode* parse_atom() { + Token t = tok_.peek(); + + switch (t.type) { + case TokenType::TK_INTEGER: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_INT, t.text); + } + case TokenType::TK_FLOAT: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_FLOAT, t.text); + } + case TokenType::TK_STRING: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_STRING, t.text); + } + case TokenType::TK_NULL: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_NULL, t.text); + } + case TokenType::TK_TRUE: + case TokenType::TK_FALSE: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_INT, t.text); + } + case TokenType::TK_DEFAULT: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_IDENTIFIER, t.text); + } + case TokenType::TK_ASTERISK: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_ASTERISK, t.text); + } + case TokenType::TK_QUESTION: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_PLACEHOLDER, t.text); + } + case TokenType::TK_DOLLAR_NUM: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_PLACEHOLDER, t.text); + } + case TokenType::TK_AT: { + // User variable: @name + tok_.skip(); + Token name = tok_.next_token(); + // Build @name as a single COLUMN_REF with combined text + // value_ptr points to @ in original input, len covers @name + StringRef full{t.text.ptr, + static_cast((name.text.ptr + name.text.len) - t.text.ptr)}; + return make_node(arena_, NodeType::NODE_COLUMN_REF, full); + } + case 
TokenType::TK_DOUBLE_AT: { + // System variable: @@name or @@scope.name + tok_.skip(); + Token name = tok_.next_token(); + StringRef full{t.text.ptr, + static_cast((name.text.ptr + name.text.len) - t.text.ptr)}; + AstNode* node = make_node(arena_, NodeType::NODE_COLUMN_REF, full); + // Check for @@scope.name + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token var_name = tok_.next_token(); + full = StringRef{t.text.ptr, + static_cast((var_name.text.ptr + var_name.text.len) - t.text.ptr)}; + node->value_ptr = full.ptr; + node->value_len = full.len; + } + return node; + } + case TokenType::TK_MINUS: { + // Unary minus + tok_.skip(); + AstNode* operand = parse(Precedence::UNARY); + if (!operand) return nullptr; + AstNode* node = make_node(arena_, NodeType::NODE_UNARY_OP, t.text); + node->add_child(operand); + return node; + } + case TokenType::TK_PLUS: { + // Unary plus + tok_.skip(); + return parse(Precedence::UNARY); + } + case TokenType::TK_NOT: { + tok_.skip(); + // NOT IN, NOT BETWEEN, NOT LIKE are not unary prefix on the + // next atom — they modify the following infix operator. + // But NOT here is in prefix position (before the operand), + // so we parse the operand, then check if the next infix is + // IN/BETWEEN/LIKE and negate it. + AstNode* operand = parse(Precedence::NOT); + if (!operand) return nullptr; + AstNode* node = make_node(arena_, NodeType::NODE_UNARY_OP, t.text); + node->add_child(operand); + return node; + } + case TokenType::TK_CASE: { + tok_.skip(); + return parse_case(); + } + case TokenType::TK_LPAREN: { + tok_.skip(); + // Could be subquery: (SELECT ...) 
+ if (tok_.peek().type == TokenType::TK_SELECT) { + // Subquery — for now, skip to matching paren + AstNode* node = make_node(arena_, NodeType::NODE_SUBQUERY); + skip_to_matching_paren(); + return node; + } + AstNode* expr = parse(); + if (tok_.peek().type == TokenType::TK_RPAREN) { + tok_.skip(); + } + return expr; + } + case TokenType::TK_IDENTIFIER: { + tok_.skip(); + return parse_identifier_or_function(t); + } + // Keywords that can appear as identifiers in expression context + // (e.g., column names that happen to be keywords) + default: { + if (is_keyword_as_identifier(t.type)) { + tok_.skip(); + return parse_identifier_or_function(t); + } + return nullptr; // not an expression + } + } + } + + AstNode* parse_identifier_or_function(const Token& name_token) { + // Check for function call: name( + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); // consume ( + AstNode* func = make_node(arena_, NodeType::NODE_FUNCTION_CALL, name_token.text); + // Parse argument list + if (tok_.peek().type != TokenType::TK_RPAREN) { + while (true) { + AstNode* arg = parse(); + if (arg) func->add_child(arg); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) { + tok_.skip(); + } + return func; + } + + // Check for qualified name: table.column + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); // consume dot + Token col = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name_token.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + return qname; + } + + return make_node(arena_, NodeType::NODE_COLUMN_REF, name_token.text); + } + + // Infix precedence for a token type. + // Returns NONE if not an infix operator (stops the Pratt loop). 
+ // NOT is handled here as a pseudo-infix: NOT IN, NOT BETWEEN, NOT LIKE + // are compound operators at COMPARISON precedence. + static Precedence infix_precedence(TokenType type) { + switch (type) { + case TokenType::TK_OR: return Precedence::OR; + case TokenType::TK_AND: return Precedence::AND; + case TokenType::TK_NOT: return Precedence::COMPARISON; // NOT IN/BETWEEN/LIKE + case TokenType::TK_EQUAL: + case TokenType::TK_NOT_EQUAL: + case TokenType::TK_LESS: + case TokenType::TK_GREATER: + case TokenType::TK_LESS_EQUAL: + case TokenType::TK_GREATER_EQUAL: + case TokenType::TK_LIKE: return Precedence::COMPARISON; + case TokenType::TK_IS: return Precedence::COMPARISON; + case TokenType::TK_IN: return Precedence::COMPARISON; + case TokenType::TK_BETWEEN: return Precedence::COMPARISON; + case TokenType::TK_PLUS: + case TokenType::TK_MINUS: return Precedence::ADDITION; + case TokenType::TK_ASTERISK: + case TokenType::TK_SLASH: + case TokenType::TK_PERCENT: return Precedence::MULTIPLICATION; + case TokenType::TK_DOUBLE_PIPE: return Precedence::ADDITION; // string concat + default: return Precedence::NONE; + } + } + + AstNode* parse_infix(AstNode* left, Precedence prec) { + Token op = tok_.next_token(); + + switch (op.type) { + case TokenType::TK_NOT: { + // NOT IN / NOT BETWEEN / NOT LIKE — compound negated infix + Token actual_op = tok_.peek(); + if (actual_op.type == TokenType::TK_IN) { + tok_.skip(); + AstNode* in_node = parse_in(left); + // Wrap in NOT + AstNode* not_node = make_node(arena_, NodeType::NODE_UNARY_OP, op.text); + not_node->add_child(in_node); + return not_node; + } + if (actual_op.type == TokenType::TK_BETWEEN) { + tok_.skip(); + AstNode* between_node = parse_between(left); + AstNode* not_node = make_node(arena_, NodeType::NODE_UNARY_OP, op.text); + not_node->add_child(between_node); + return not_node; + } + if (actual_op.type == TokenType::TK_LIKE) { + tok_.skip(); + AstNode* right = parse(prec); + AstNode* like_node = make_node(arena_, 
NodeType::NODE_BINARY_OP, actual_op.text); + like_node->add_child(left); + if (right) like_node->add_child(right); + AstNode* not_node = make_node(arena_, NodeType::NODE_UNARY_OP, op.text); + not_node->add_child(like_node); + return not_node; + } + // Standalone NOT in infix position — shouldn't happen, return left + return left; + } + case TokenType::TK_IS: { + // IS [NOT] NULL + bool is_not = false; + if (tok_.peek().type == TokenType::TK_NOT) { + is_not = true; + tok_.skip(); + } + if (tok_.peek().type == TokenType::TK_NULL) { + tok_.skip(); + NodeType nt = is_not ? NodeType::NODE_IS_NOT_NULL : NodeType::NODE_IS_NULL; + AstNode* node = make_node(arena_, nt); + node->add_child(left); + return node; + } + // IS TRUE / IS FALSE / IS NOT TRUE / IS NOT FALSE + if (tok_.peek().type == TokenType::TK_TRUE || tok_.peek().type == TokenType::TK_FALSE) { + Token val = tok_.next_token(); + AstNode* node = make_node(arena_, NodeType::NODE_BINARY_OP, + is_not ? StringRef{"IS NOT", 6} : StringRef{"IS", 2}); + node->add_child(left); + node->add_child(make_node(arena_, NodeType::NODE_LITERAL_INT, val.text)); + return node; + } + return left; + } + case TokenType::TK_IN: + return parse_in(left); + case TokenType::TK_BETWEEN: + return parse_between(left); + default: { + // Standard binary operator + AstNode* right = parse(prec); + if (!right) return left; + AstNode* node = make_node(arena_, NodeType::NODE_BINARY_OP, op.text); + node->add_child(left); + node->add_child(right); + return node; + } + } + } + + // IN (value_list) or IN (subquery) + AstNode* parse_in(AstNode* left) { + AstNode* node = make_node(arena_, NodeType::NODE_IN_LIST); + node->add_child(left); + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_SELECT) { + AstNode* sq = make_node(arena_, NodeType::NODE_SUBQUERY); + skip_to_matching_paren(); + node->add_child(sq); + } else { + while (true) { + AstNode* val = parse(); + if (val) node->add_child(val); + if 
(tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) tok_.skip(); + } + } + return node; + } + + // BETWEEN low AND high + AstNode* parse_between(AstNode* left) { + AstNode* node = make_node(arena_, NodeType::NODE_BETWEEN); + node->add_child(left); + AstNode* low = parse(Precedence::COMPARISON); + node->add_child(low); + if (tok_.peek().type == TokenType::TK_AND) { + tok_.skip(); + } + AstNode* high = parse(Precedence::COMPARISON); + node->add_child(high); + return node; + } + + // CASE [expr] WHEN ... THEN ... [ELSE ...] END + AstNode* parse_case() { + AstNode* node = make_node(arena_, NodeType::NODE_CASE_WHEN); + // Optional simple CASE expression: CASE expr WHEN ... + if (tok_.peek().type != TokenType::TK_WHEN) { + AstNode* case_expr = parse(); + if (case_expr) node->add_child(case_expr); + } + // WHEN ... THEN ... pairs + while (tok_.peek().type == TokenType::TK_WHEN) { + tok_.skip(); + AstNode* when_expr = parse(); + if (when_expr) node->add_child(when_expr); + if (tok_.peek().type == TokenType::TK_THEN) tok_.skip(); + AstNode* then_expr = parse(); + if (then_expr) node->add_child(then_expr); + } + // Optional ELSE + if (tok_.peek().type == TokenType::TK_ELSE) { + tok_.skip(); + AstNode* else_expr = parse(); + if (else_expr) node->add_child(else_expr); + } + // END + if (tok_.peek().type == TokenType::TK_END) tok_.skip(); + return node; + } + + // Skip tokens until matching closing paren (handles nesting) + void skip_to_matching_paren() { + int depth = 1; + while (depth > 0) { + Token t = tok_.next_token(); + if (t.type == TokenType::TK_LPAREN) ++depth; + else if (t.type == TokenType::TK_RPAREN) --depth; + else if (t.type == TokenType::TK_EOF) break; + } + } + + // Some keywords can appear as identifiers in expression context + static bool is_keyword_as_identifier(TokenType type) { + switch (type) { + // Keywords commonly used as column/table names + case TokenType::TK_COUNT: + 
case TokenType::TK_SUM: + case TokenType::TK_AVG: + case TokenType::TK_MIN: + case TokenType::TK_MAX: + case TokenType::TK_IF: + case TokenType::TK_VALUES: + case TokenType::TK_DATABASE: + case TokenType::TK_SCHEMA: + case TokenType::TK_TABLE: + case TokenType::TK_INDEX: + case TokenType::TK_VIEW: + case TokenType::TK_NAMES: + case TokenType::TK_CHARACTER: + case TokenType::TK_CHARSET: + case TokenType::TK_GLOBAL: + case TokenType::TK_SESSION: + case TokenType::TK_LOCAL: + case TokenType::TK_LEVEL: + case TokenType::TK_READ: + case TokenType::TK_WRITE: + case TokenType::TK_ONLY: + case TokenType::TK_TRANSACTION: + case TokenType::TK_ISOLATION: + case TokenType::TK_COMMITTED: + case TokenType::TK_UNCOMMITTED: + case TokenType::TK_REPEATABLE: + case TokenType::TK_SERIALIZABLE: + case TokenType::TK_SHARE: + case TokenType::TK_DATA: + case TokenType::TK_RESET: + return true; + default: + return false; + } + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_EXPRESSION_PARSER_H +``` + +- [ ] **Step 3: Add test_expression.cpp to Makefile.new** + +In `Makefile.new`, add `$(TEST_DIR)/test_expression.cpp` and `$(TEST_DIR)/test_set.cpp` to `TEST_SRCS`: + +Change the `TEST_SRCS` line to: +```makefile +TEST_SRCS = $(TEST_DIR)/test_main.cpp \ + $(TEST_DIR)/test_arena.cpp \ + $(TEST_DIR)/test_tokenizer.cpp \ + $(TEST_DIR)/test_classifier.cpp \ + $(TEST_DIR)/test_expression.cpp \ + $(TEST_DIR)/test_set.cpp +``` + +Also create an empty `tests/test_set.cpp` placeholder so the build doesn't break: +```cpp +#include +// SET parser tests will be added in Task 3 +``` + +- [ ] **Step 4: Build and run tests** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: All existing tests pass + new expression tests pass. 
+ +- [ ] **Step 5: Commit** + +```bash +git add include/sql_parser/expression_parser.h tests/test_expression.cpp tests/test_set.cpp Makefile.new +git commit -m "feat: add Pratt expression parser with literals, identifiers, and operators" +``` + +--- + +### Task 2: Expression Parser — Binary Operators, IS NULL, BETWEEN, IN, Functions + +**Files:** +- Modify: `tests/test_expression.cpp` — add operator and complex expression tests + +- [ ] **Step 1: Add binary operator and complex expression tests** + +Append to `tests/test_expression.cpp`: +```cpp +TEST_F(ExpressionTest, BinaryAdd) { + AstNode* node = parse_expr("1 + 2"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "+"); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_LITERAL_INT); + ASSERT_NE(node->first_child->next_sibling, nullptr); + EXPECT_EQ(node->first_child->next_sibling->type, NodeType::NODE_LITERAL_INT); +} + +TEST_F(ExpressionTest, Precedence_MulOverAdd) { + // 1 + 2 * 3 should parse as 1 + (2 * 3) + AstNode* node = parse_expr("1 + 2 * 3"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "+"); + // Right child should be 2*3 + AstNode* right = node->first_child->next_sibling; + ASSERT_NE(right, nullptr); + EXPECT_EQ(right->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(right->value_ptr, right->value_len), "*"); +} + +TEST_F(ExpressionTest, ComparisonEqual) { + AstNode* node = parse_expr("x = 1"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "="); +} + +TEST_F(ExpressionTest, LogicalAnd) { + AstNode* node = parse_expr("a = 1 AND b = 2"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "AND"); 
+} + +TEST_F(ExpressionTest, LogicalOr) { + AstNode* node = parse_expr("a = 1 OR b = 2"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "OR"); +} + +TEST_F(ExpressionTest, UnaryMinus) { + AstNode* node = parse_expr("-42"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_LITERAL_INT); +} + +TEST_F(ExpressionTest, UnaryNot) { + AstNode* node = parse_expr("NOT x"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); +} + +TEST_F(ExpressionTest, IsNull) { + AstNode* node = parse_expr("x IS NULL"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IS_NULL); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_COLUMN_REF); +} + +TEST_F(ExpressionTest, IsNotNull) { + AstNode* node = parse_expr("x IS NOT NULL"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IS_NOT_NULL); +} + +TEST_F(ExpressionTest, Between) { + AstNode* node = parse_expr("x BETWEEN 1 AND 10"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BETWEEN); + // 3 children: expr, low, high + ASSERT_NE(node->first_child, nullptr); + ASSERT_NE(node->first_child->next_sibling, nullptr); + ASSERT_NE(node->first_child->next_sibling->next_sibling, nullptr); +} + +TEST_F(ExpressionTest, InList) { + AstNode* node = parse_expr("x IN (1, 2, 3)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IN_LIST); + // Children: expr, val1, val2, val3 = 4 children + int count = 0; + for (AstNode* c = node->first_child; c; c = c->next_sibling) ++count; + EXPECT_EQ(count, 4); +} + +TEST_F(ExpressionTest, FunctionCall) { + AstNode* node = parse_expr("COUNT(*)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_FUNCTION_CALL); + EXPECT_EQ(std::string(node->value_ptr, 
node->value_len), "COUNT"); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_ASTERISK); +} + +TEST_F(ExpressionTest, FunctionCallMultiArg) { + AstNode* node = parse_expr("COALESCE(a, b, 0)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_FUNCTION_CALL); + int count = 0; + for (AstNode* c = node->first_child; c; c = c->next_sibling) ++count; + EXPECT_EQ(count, 3); +} + +TEST_F(ExpressionTest, NestedParens) { + AstNode* node = parse_expr("(1 + 2) * 3"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "*"); + // Left child should be 1+2 + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->first_child->value_ptr, node->first_child->value_len), "+"); +} + +TEST_F(ExpressionTest, LikeOperator) { + AstNode* node = parse_expr("name LIKE '%test%'"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "LIKE"); +} + +TEST_F(ExpressionTest, StringConcat) { + AstNode* node = parse_expr("a || b"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); +} + +TEST_F(ExpressionTest, NotIn) { + AstNode* node = parse_expr("x NOT IN (1, 2)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); // NOT wraps IN_LIST + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_IN_LIST); +} + +TEST_F(ExpressionTest, NotBetween) { + AstNode* node = parse_expr("x NOT BETWEEN 1 AND 10"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); // NOT wraps BETWEEN + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_BETWEEN); +} + +TEST_F(ExpressionTest, NotLike) { + AstNode* node = parse_expr("name NOT LIKE '%test'"); + 
ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); // NOT wraps LIKE + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_BINARY_OP); +} + +TEST_F(ExpressionTest, CaseWhenSimple) { + AstNode* node = parse_expr("CASE WHEN x = 1 THEN 'a' ELSE 'b' END"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_CASE_WHEN); +} + +TEST_F(ExpressionTest, CaseWhenSearched) { + AstNode* node = parse_expr("CASE x WHEN 1 THEN 'a' WHEN 2 THEN 'b' END"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_CASE_WHEN); +} + +TEST_F(ExpressionTest, ZeroArgFunction) { + AstNode* node = parse_expr("NOW()"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_FUNCTION_CALL); + EXPECT_EQ(node->first_child, nullptr); // no args +} +``` + +- [ ] **Step 2: Build and run all tests** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: All tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_expression.cpp +git commit -m "test: add operator, IS NULL, BETWEEN, IN, and function call expression tests" +``` + +--- + +### Task 3: SET Deep Parser + +**Files:** +- Create: `include/sql_parser/set_parser.h` +- Modify: `tests/test_set.cpp` — add SET tests +- Modify: `src/sql_parser/parser.cpp` — replace `parse_set()` stub +- Modify: `include/sql_parser/parser.h` — add SET parser include and method declarations + +- [ ] **Step 1: Write SET parser tests** + +Replace `tests/test_set.cpp` with: +```cpp +#include +#include "sql_parser/parser.h" + +using namespace sql_parser; + +class MySQLSetTest : public ::testing::Test { +protected: + Parser parser; + + // Helper to count children of a node + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } +}; + +TEST_F(MySQLSetTest, SetSimpleVariable) { + auto r = parser.parse("SET autocommit = 1", 18); + 
EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_SET_STMT); + // One assignment child + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_VAR_ASSIGNMENT); +} + +TEST_F(MySQLSetTest, SetMultipleVariables) { + auto r = parser.parse("SET autocommit = 1, wait_timeout = 28800", 40); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(child_count(r.ast), 2); +} + +TEST_F(MySQLSetTest, SetGlobalVariable) { + auto r = parser.parse("SET GLOBAL max_connections = 100", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + // First child of assignment is the var target + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + EXPECT_EQ(target->type, NodeType::NODE_VAR_TARGET); +} + +TEST_F(MySQLSetTest, SetSessionVariable) { + auto r = parser.parse("SET SESSION wait_timeout = 600", 30); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetDoubleAtVariable) { + auto r = parser.parse("SET @@session.wait_timeout = 600", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetUserVariable) { + auto r = parser.parse("SET @my_var = 42", 16); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetNames) { + auto r = parser.parse("SET NAMES utf8mb4", 17); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_SET_STMT); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_NAMES); +} + +TEST_F(MySQLSetTest, SetNamesCollate) { + auto r = parser.parse("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci", 44); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, 
nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_NAMES); + // Should have 2 children: charset and collation + EXPECT_EQ(child_count(r.ast->first_child), 2); +} + +TEST_F(MySQLSetTest, SetCharacterSet) { + auto r = parser.parse("SET CHARACTER SET utf8", 22); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_CHARSET); +} + +TEST_F(MySQLSetTest, SetCharset) { + auto r = parser.parse("SET CHARSET utf8", 16); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_CHARSET); +} + +TEST_F(MySQLSetTest, SetTransaction) { + auto r = parser.parse("SET TRANSACTION READ ONLY", 25); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_TRANSACTION); +} + +TEST_F(MySQLSetTest, SetTransactionIsolation) { + auto r = parser.parse("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ", 47); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetGlobalTransaction) { + auto r = parser.parse("SET GLOBAL TRANSACTION READ WRITE", 33); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetExpressionRHS) { + auto r = parser.parse("SET @x = 1 + 2", 14); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetColonEqual) { + auto r = parser.parse("SET @x := 42", 12); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetNamesDefault) { + auto r = parser.parse("SET NAMES DEFAULT", 17); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetWithSemicolon) { + const char* sql = "SET 
autocommit = 0; BEGIN"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + EXPECT_TRUE(r.has_remaining()); +} + +// ========== PostgreSQL SET Tests ========== + +class PgSQLSetTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(PgSQLSetTest, SetVarToValue) { + auto r = parser.parse("SET client_encoding TO 'UTF8'", 29); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetVarEqualValue) { + auto r = parser.parse("SET work_mem = '256MB'", 22); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetLocalVar) { + auto r = parser.parse("SET LOCAL timezone = 'UTC'", 26); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetNamesPostgres) { + auto r = parser.parse("SET NAMES 'UTF8'", 16); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} +``` + +- [ ] **Step 2: Write set_parser.h** + +Create `include/sql_parser/set_parser.h`: +```cpp +#ifndef SQL_PARSER_SET_PARSER_H +#define SQL_PARSER_SET_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +template <Dialect D> +class SetParser { +public: + SetParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {} + + // Parse a SET statement (SET keyword already consumed by classifier). + // Returns the root NODE_SET_STMT node, or nullptr on failure. + AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_SET_STMT); + if (!root) return nullptr; + + Token next = tok_.peek(); + + // SET NAMES ... 
+ if (next.type == TokenType::TK_NAMES) { + tok_.skip(); + AstNode* names_node = parse_set_names(); + if (names_node) root->add_child(names_node); + return root; + } + + // SET CHARACTER SET ... or SET CHARSET ... + if (next.type == TokenType::TK_CHARACTER) { + tok_.skip(); + // Expect SET keyword + if (tok_.peek().type == TokenType::TK_SET) { + tok_.skip(); + } + AstNode* charset_node = parse_set_charset(); + if (charset_node) root->add_child(charset_node); + return root; + } + if (next.type == TokenType::TK_CHARSET) { + tok_.skip(); + AstNode* charset_node = parse_set_charset(); + if (charset_node) root->add_child(charset_node); + return root; + } + + // SET [GLOBAL|SESSION] TRANSACTION ... + // Need to check for scope + TRANSACTION or just TRANSACTION + if (next.type == TokenType::TK_TRANSACTION) { + tok_.skip(); + AstNode* txn_node = parse_set_transaction(StringRef{}); + if (txn_node) root->add_child(txn_node); + return root; + } + + if (next.type == TokenType::TK_GLOBAL || next.type == TokenType::TK_SESSION) { + Token scope_tok = tok_.next_token(); + if (tok_.peek().type == TokenType::TK_TRANSACTION) { + tok_.skip(); + AstNode* txn_node = parse_set_transaction(scope_tok.text); + if (txn_node) root->add_child(txn_node); + return root; + } + // Not TRANSACTION — it's SET GLOBAL var = expr + // Fall through to variable assignment with scope + AstNode* assignment = parse_variable_assignment(&scope_tok); + if (assignment) root->add_child(assignment); + // Parse remaining comma-separated assignments + while (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + AstNode* next_assign = parse_variable_assignment(nullptr); + if (next_assign) root->add_child(next_assign); + } + return root; + } + + // PostgreSQL: SET LOCAL var = expr + if constexpr (D == Dialect::PostgreSQL) { + if (next.type == TokenType::TK_LOCAL) { + Token scope_tok = tok_.next_token(); + AstNode* assignment = parse_variable_assignment(&scope_tok); + if (assignment) root->add_child(assignment); 
+ return root; + } + } + + // SET var = expr [, var = expr, ...] + AstNode* assignment = parse_variable_assignment(nullptr); + if (assignment) root->add_child(assignment); + while (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + AstNode* next_assign = parse_variable_assignment(nullptr); + if (next_assign) root->add_child(next_assign); + } + + return root; + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + + // SET NAMES charset [COLLATE collation] + AstNode* parse_set_names() { + AstNode* node = make_node(arena_, NodeType::NODE_SET_NAMES); + if (!node) return nullptr; + + // charset name or DEFAULT + Token charset = tok_.next_token(); + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, charset.text)); + + // Optional COLLATE + if (tok_.peek().type == TokenType::TK_COLLATE) { + tok_.skip(); + Token collation = tok_.next_token(); + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, collation.text)); + } + return node; + } + + // SET CHARACTER SET charset / SET CHARSET charset + AstNode* parse_set_charset() { + AstNode* node = make_node(arena_, NodeType::NODE_SET_CHARSET); + if (!node) return nullptr; + + Token charset = tok_.next_token(); + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, charset.text)); + return node; + } + + // SET [GLOBAL|SESSION] TRANSACTION ... + AstNode* parse_set_transaction(StringRef scope) { + AstNode* node = make_node(arena_, NodeType::NODE_SET_TRANSACTION); + if (!node) return nullptr; + + if (!scope.empty()) { + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, scope)); + } + + // ISOLATION LEVEL ... 
or READ ONLY/WRITE + Token next = tok_.peek(); + if (next.type == TokenType::TK_ISOLATION) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_LEVEL) tok_.skip(); + + // READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE + Token level = tok_.next_token(); + if (level.type == TokenType::TK_READ) { + Token sublevel = tok_.next_token(); + // Combine "READ COMMITTED" or "READ UNCOMMITTED" + StringRef combined{level.text.ptr, + static_cast((sublevel.text.ptr + sublevel.text.len) - level.text.ptr)}; + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined)); + } else if (level.type == TokenType::TK_REPEATABLE) { + Token read_tok = tok_.next_token(); // READ + StringRef combined{level.text.ptr, + static_cast((read_tok.text.ptr + read_tok.text.len) - level.text.ptr)}; + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined)); + } else { + // SERIALIZABLE + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, level.text)); + } + } else if (next.type == TokenType::TK_READ) { + tok_.skip(); + Token rw = tok_.next_token(); // ONLY or WRITE + StringRef combined{next.text.ptr, + static_cast((rw.text.ptr + rw.text.len) - next.text.ptr)}; + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined)); + } + + return node; + } + + // Parse a single variable assignment: [scope] target = expr + // scope_token is non-null if GLOBAL/SESSION/LOCAL was already consumed + AstNode* parse_variable_assignment(const Token* scope_token) { + AstNode* assignment = make_node(arena_, NodeType::NODE_VAR_ASSIGNMENT); + if (!assignment) return nullptr; + + // Build the variable target + AstNode* target = make_node(arena_, NodeType::NODE_VAR_TARGET); + if (!target) return nullptr; + + if (scope_token) { + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, scope_token->text)); + } + + Token var = tok_.peek(); + if (var.type == TokenType::TK_AT) { + // User variable @name + tok_.skip(); + Token name = tok_.next_token(); + 
StringRef full{var.text.ptr, + static_cast((name.text.ptr + name.text.len) - var.text.ptr)}; + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, full)); + } else if (var.type == TokenType::TK_DOUBLE_AT) { + // System variable @@[scope.]name + tok_.skip(); + Token name = tok_.next_token(); + StringRef full{var.text.ptr, + static_cast((name.text.ptr + name.text.len) - var.text.ptr)}; + // Check for @@scope.name + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token actual_name = tok_.next_token(); + full = StringRef{var.text.ptr, + static_cast((actual_name.text.ptr + actual_name.text.len) - var.text.ptr)}; + } + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, full)); + } else { + // Plain variable name + Token name = tok_.next_token(); + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + } + + assignment->add_child(target); + + // Expect = or := (MySQL) or TO (PostgreSQL) + Token eq = tok_.peek(); + if (eq.type == TokenType::TK_EQUAL || eq.type == TokenType::TK_COLON_EQUAL) { + tok_.skip(); + } else if constexpr (D == Dialect::PostgreSQL) { + if (eq.type == TokenType::TK_TO) { + tok_.skip(); + } + } + + // Parse RHS expression + AstNode* rhs = expr_parser_.parse(); + if (rhs) assignment->add_child(rhs); + + return assignment; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_SET_PARSER_H +``` + +- [ ] **Step 3: Integrate SET parser into Parser class** + +Modify `src/sql_parser/parser.cpp` — add includes for set_parser.h and expression_parser.h at the top of the file (these are implementation details, not public API, so they belong in the .cpp not the .h): + +Add after `#include "sql_parser/parser.h"`: +```cpp +#include "sql_parser/expression_parser.h" +#include "sql_parser/set_parser.h" +``` + +Then replace the `parse_set()` stub: + +Replace: +```cpp +template +ParseResult Parser::parse_set() { + ParseResult r; + r.status = ParseResult::PARTIAL; + r.stmt_type = StmtType::SET; + 
scan_to_end(r); + return r; +} +``` + +With: +```cpp +template <Dialect D> +ParseResult Parser<D>::parse_set() { + ParseResult r; + r.stmt_type = StmtType::SET; + + SetParser<D> set_parser(tokenizer_, arena_); + AstNode* ast = set_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} +``` + +- [ ] **Step 4: Build and run all tests** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: All tests pass including new SET tests. + +- [ ] **Step 5: Commit** + +```bash +git add include/sql_parser/set_parser.h include/sql_parser/parser.h src/sql_parser/parser.cpp tests/test_set.cpp +git commit -m "feat: add SET deep parser with full AST for all SET variants" +``` + +--- + +### Task 4: Verify existing classifier tests still pass and clean up + +**Files:** +- No new files — verification and cleanup only + +- [ ] **Step 1: Run full test suite** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: ALL tests pass, zero warnings. + +- [ ] **Step 2: Verify SET classifier tests now return OK instead of PARTIAL** + +The existing `ClassifySet` test in `test_classifier.cpp` checked for `stmt_type == StmtType::SET` but did not check `status`. The SET parser now returns `OK` instead of `PARTIAL`. Verify this doesn't break anything: + +Run: +```bash +./run_tests --gtest_filter="*Set*" +``` + +- [ ] **Step 3: Check for compiler warnings** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all 2>&1 | grep -i warning +``` +Expected: Zero warnings (or only from Google Test internals). + +- [ ] **Step 4: Commit if any fixes were needed** + +```bash +# Only if changes were made +git add -A && git commit -m "fix: clean up warnings and test compatibility after SET parser integration" +``` + +--- + +## What's Next + +After this plan is complete, the following plans remain: + +1. 
**Plan 3: SELECT Deep Parser** — Full SELECT parsing with FROM, JOIN, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, locking. Uses the expression parser from this plan. +2. **Plan 4: Query Emitter** — AST → SQL reconstruction. +3. **Plan 5: Prepared Statement Cache** — Binary protocol support. +4. **Plan 6: Performance Benchmarks** — Validate latency targets. diff --git a/docs/superpowers/plans/2026-03-24-plan10-compound-queries.md b/docs/superpowers/plans/2026-03-24-plan10-compound-queries.md new file mode 100644 index 0000000..9aab922 --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-plan10-compound-queries.md @@ -0,0 +1,628 @@ +# Compound Query (UNION/INTERSECT/EXCEPT) Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build a `CompoundQueryParser` that handles UNION [ALL], INTERSECT [ALL], and EXCEPT [ALL] with correct precedence (INTERSECT binds tighter than UNION/EXCEPT), parenthesized nesting, and trailing ORDER BY/LIMIT on compound results. + +**Architecture:** `CompoundQueryParser` is a separate layer above `SelectParser`. It uses Pratt-style precedence parsing: INTERSECT has higher precedence than UNION/EXCEPT (which are equal). Each operand is either a parenthesized compound (recursive) or a single SELECT via `SelectParser`. If no set operator follows the first SELECT, the compound parser returns the bare `NODE_SELECT_STMT` as-is with zero overhead. The parser is integrated by updating `Parser::parse_select()` to call `CompoundQueryParser` instead of `SelectParser` directly. + +**Tech Stack:** C++17, existing parser infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` + +--- + +## Scope + +This plan builds: +1. New tokens: `TK_INTERSECT`, `TK_EXCEPT` +2. 
New node types: `NODE_COMPOUND_QUERY`, `NODE_SET_OPERATION` +3. `CompoundQueryParser` — header-only template for compound query parsing +4. Emitter extensions for compound query nodes +5. Integration into `parse_select()` flow +6. Comprehensive tests with precedence verification + +**Closes:** #8 + +**Dependencies:** None strictly required, but Plan 7's `is_alias_start()` update (adding TK_INTERSECT, TK_EXCEPT to blocklist) is helpful. If Plan 7 is not done first, this plan must do that update itself. + +--- + +## File Structure + +``` +include/sql_parser/ + compound_query_parser.h — (create) UNION/INTERSECT/EXCEPT parser + emitter.h — (modify) add compound query emit methods + common.h — (modify) add NODE_COMPOUND_QUERY, NODE_SET_OPERATION + token.h — (modify) add TK_INTERSECT, TK_EXCEPT + select_parser.h — (modify) update is_alias_start blocklist (if not done by Plan 7) + +src/sql_parser/ + parser.cpp — (modify) update parse_select() to use CompoundQueryParser + +include/sql_parser/ + keywords_mysql.h — (modify) add INTERSECT, EXCEPT keywords + keywords_pgsql.h — (modify) add INTERSECT, EXCEPT keywords + +tests/ + test_compound.cpp — (create) compound query tests + +Makefile.new — (modify) add test_compound.cpp to TEST_SRCS +``` + +--- + +### Task 1: Add New Tokens and Node Types + +**Files:** +- Modify: `include/sql_parser/token.h` +- Modify: `include/sql_parser/common.h` +- Modify: `include/sql_parser/keywords_mysql.h` +- Modify: `include/sql_parser/keywords_pgsql.h` + +- [ ] **Step 1: Add new token types** + +Add to `TokenType` enum: +```cpp +TK_INTERSECT, +TK_EXCEPT, +``` + +Note: `TK_UNION` already exists. + +- [ ] **Step 2: Add new node types** + +Add to `NodeType` enum: +```cpp +// Compound query nodes +NODE_COMPOUND_QUERY, // root for UNION/INTERSECT/EXCEPT +NODE_SET_OPERATION, // operator (UNION, INTERSECT, EXCEPT) with ALL flag +``` + +- [ ] **Step 3: Register keywords** + +Add `INTERSECT` and `EXCEPT` to both `keywords_mysql.h` and `keywords_pgsql.h`. 
+ +- [ ] **Step 4: Update `is_alias_start()` blocklist** + +In `TableRefParser::is_alias_start()` (or `SelectParser` if Plan 7 is not yet applied), add: +```cpp +case TokenType::TK_INTERSECT: +case TokenType::TK_EXCEPT: +``` + +These keywords start compound operators and must not be misinterpreted as implicit aliases. + +--- + +### Task 2: CompoundQueryParser Implementation + +**Files:** +- Create: `include/sql_parser/compound_query_parser.h` +- Create: `tests/test_compound.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Write tests for compound queries** + +Create `tests/test_compound.cpp`: +```cpp +#include <gtest/gtest.h> +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLCompoundTest : public ::testing::Test { +protected: + Parser<Dialect::MySQL> parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Simple SELECT (no compound) ========== +// CompoundQueryParser must return bare NODE_SELECT_STMT when no set operator follows + +TEST_F(MySQLCompoundTest, PlainSelectUnchanged) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + // Should be NODE_SELECT_STMT, NOT NODE_COMPOUND_QUERY + EXPECT_EQ(r.ast->type, NodeType::NODE_SELECT_STMT); +} + +// ========== UNION ========== + +TEST_F(MySQLCompoundTest, SimpleUnion) { + const char* sql = "SELECT 1 UNION SELECT 2"; + auto r = 
parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + auto* setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(setop, nullptr); +} + +TEST_F(MySQLCompoundTest, UnionAll) { + const char* sql = "SELECT 1 UNION ALL SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, UnionThreeSelects) { + const char* sql = "SELECT 1 UNION SELECT 2 UNION SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, UnionWithOrderBy) { + const char* sql = "SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); +} + +TEST_F(MySQLCompoundTest, UnionWithLimit) { + const char* sql = "SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLCompoundTest, UnionWithOrderByAndLimit) { + const char* sql = "SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a LIMIT 5"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +// ========== INTERSECT ========== + +TEST_F(MySQLCompoundTest, SimpleIntersect) { + const char* sql = "SELECT 1 
INTERSECT SELECT 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, IntersectAll) { + const char* sql = "SELECT 1 INTERSECT ALL SELECT 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== EXCEPT ========== + +TEST_F(MySQLCompoundTest, SimpleExcept) { + const char* sql = "SELECT 1 EXCEPT SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, ExceptAll) { + const char* sql = "SELECT 1 EXCEPT ALL SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Precedence: INTERSECT > UNION/EXCEPT ========== + +TEST_F(MySQLCompoundTest, IntersectBindsTighterThanUnion) { + // SELECT 1 UNION SELECT 2 INTERSECT SELECT 3 + // Should parse as: SELECT 1 UNION (SELECT 2 INTERSECT SELECT 3) + const char* sql = "SELECT 1 UNION SELECT 2 INTERSECT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + + // The top-level set operation should be UNION + auto* top_setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(top_setop, nullptr); + // The value should contain "UNION" + StringRef op_text = top_setop->value(); + EXPECT_TRUE(op_text.equals_ci("UNION", 5)); + + // The right child of UNION should be a SET_OPERATION (INTERSECT) + const AstNode* left = top_setop->first_child; + ASSERT_NE(left, nullptr); + const AstNode* right = left->next_sibling; + ASSERT_NE(right, nullptr); + EXPECT_EQ(right->type, NodeType::NODE_SET_OPERATION); + StringRef right_op = right->value(); + 
EXPECT_TRUE(right_op.equals_ci("INTERSECT", 9)); +} + +TEST_F(MySQLCompoundTest, IntersectBindsTighterThanExcept) { + // SELECT 1 EXCEPT SELECT 2 INTERSECT SELECT 3 + // Should parse as: SELECT 1 EXCEPT (SELECT 2 INTERSECT SELECT 3) + const char* sql = "SELECT 1 EXCEPT SELECT 2 INTERSECT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* top_setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(top_setop, nullptr); + StringRef op_text = top_setop->value(); + EXPECT_TRUE(op_text.equals_ci("EXCEPT", 6)); +} + +// ========== Parenthesized nesting ========== + +TEST_F(MySQLCompoundTest, ParenthesizedUnion) { + const char* sql = "(SELECT 1) UNION (SELECT 2)"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, ParenthesizedOverridesPrecedence) { + // (SELECT 1 UNION SELECT 2) INTERSECT SELECT 3 + // Parentheses force UNION to be evaluated first + const char* sql = "(SELECT 1 UNION SELECT 2) INTERSECT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + + auto* top_setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(top_setop, nullptr); + StringRef op_text = top_setop->value(); + EXPECT_TRUE(op_text.equals_ci("INTERSECT", 9)); +} + +// ========== Complex compound queries ========== + +TEST_F(MySQLCompoundTest, UnionWithFullSelects) { + const char* sql = "SELECT a, b FROM t1 WHERE x = 1 UNION ALL SELECT a, b FROM t2 WHERE y = 2 ORDER BY a LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +// ========== PostgreSQL compound queries ========== + +class 
PgSQLCompoundTest : public ::testing::Test { +protected: + Parser parser; + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLCompoundTest, SimpleUnion) { + const char* sql = "SELECT 1 UNION SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(PgSQLCompoundTest, IntersectExcept) { + const char* sql = "SELECT 1 INTERSECT SELECT 2 EXCEPT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLCompoundTest, UnionReturnsCorrectDialect) { + const char* sql = "SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a LIMIT 5"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct CompoundTestCase { + const char* sql; + const char* description; +}; + +static const CompoundTestCase compound_bulk_cases[] = { + {"SELECT 1 UNION SELECT 2", "simple union"}, + {"SELECT 1 UNION ALL SELECT 2", "union all"}, + {"SELECT 1 UNION SELECT 2 UNION SELECT 3", "triple union"}, + {"SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3", "triple union all"}, + {"SELECT 1 INTERSECT SELECT 2", "simple intersect"}, + {"SELECT 1 INTERSECT ALL SELECT 2", "intersect all"}, + {"SELECT 1 EXCEPT SELECT 2", "simple except"}, + {"SELECT 1 EXCEPT ALL SELECT 2", "except all"}, + {"SELECT 1 UNION SELECT 2 INTERSECT SELECT 3", "union + intersect 
precedence"}, + {"SELECT 1 EXCEPT SELECT 2 INTERSECT SELECT 3", "except + intersect precedence"}, + {"(SELECT 1) UNION (SELECT 2)", "parenthesized"}, + {"(SELECT 1 UNION SELECT 2) INTERSECT SELECT 3", "paren override"}, + {"SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a", "trailing order by"}, + {"SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10", "trailing limit"}, + {"SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a LIMIT 5", "trailing order by + limit"}, + {"SELECT * FROM t1 WHERE x = 1 UNION SELECT * FROM t2 WHERE y = 2", "union with where"}, +}; + +TEST(MySQLCompoundBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : compound_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +TEST(PgSQLCompoundBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : compound_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const CompoundTestCase compound_roundtrip_cases[] = { + {"SELECT 1 UNION SELECT 2", "simple union"}, + {"SELECT 1 UNION ALL SELECT 2", "union all"}, + {"SELECT 1 INTERSECT SELECT 2", "intersect"}, + {"SELECT 1 EXCEPT SELECT 2", "except"}, + {"SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a", "with order by"}, + {"SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10", "with limit"}, +}; + +TEST(MySQLCompoundRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : compound_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; 
+ Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} +``` + +- [ ] **Step 2: Add test_compound.cpp to Makefile.new** + +Add `$(TEST_DIR)/test_compound.cpp \` to the `TEST_SRCS` list. + +- [ ] **Step 3: Implement CompoundQueryParser class** + +Create `include/sql_parser/compound_query_parser.h`: +```cpp +#ifndef SQL_PARSER_COMPOUND_QUERY_PARSER_H +#define SQL_PARSER_COMPOUND_QUERY_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/select_parser.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +// Flag on NODE_SET_OPERATION to indicate ALL +static constexpr uint16_t FLAG_SET_OP_ALL = 0x01; + +template +class CompoundQueryParser { +public: + CompoundQueryParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {} + + // Parse a compound query (or a plain SELECT if no set operator follows). + // Returns NODE_SELECT_STMT for plain selects, NODE_COMPOUND_QUERY for compounds. 
+ AstNode* parse(); + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + + // Precedence levels + static constexpr int PREC_UNION_EXCEPT = 1; + static constexpr int PREC_INTERSECT = 2; + + // Parse a compound expression with minimum precedence (Pratt-style) + AstNode* parse_compound_expr(int min_prec); + + // Parse a single operand: parenthesized compound or plain SELECT + AstNode* parse_operand(); + + // Get the precedence of a set operator token, or 0 if not a set operator + static int get_set_op_precedence(TokenType type); + + // Check if a token is a set operator + static bool is_set_operator(TokenType type); + + // Parse trailing ORDER BY for compound result + AstNode* parse_order_by(); + + // Parse trailing LIMIT for compound result + AstNode* parse_limit(); +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_COMPOUND_QUERY_PARSER_H +``` + +- [ ] **Step 4: Implement CompoundQueryParser methods** + +Key implementation logic for `parse()`: + +``` +parse(): + 1. result = parse_compound_expr(0) + 2. if result is NODE_SET_OPERATION: + wrap in NODE_COMPOUND_QUERY + parse trailing ORDER BY / LIMIT as children of COMPOUND_QUERY + 3. if result is NODE_SELECT_STMT: + return as-is (no compound wrapper for plain selects) + +parse_compound_expr(min_prec): + 1. left = parse_operand() + 2. while (peek is set operator AND get_set_op_precedence(peek) > min_prec): + op_token = consume set operator + consume optional ALL + right = parse_compound_expr(get_set_op_precedence(op_token)) + left = make NODE_SET_OPERATION with left, right as children + 3. return left + +parse_operand(): + 1. if peek is '(': + consume '(' + if next is SELECT: inner = parse_compound_expr(0) + consume ')' + return inner + 2. 
else: + return SelectParser<D>(tok_, arena_).parse() + +get_set_op_precedence(type): + TK_UNION, TK_EXCEPT => PREC_UNION_EXCEPT + TK_INTERSECT => PREC_INTERSECT + otherwise => 0 +``` + +The NODE_SET_OPERATION node's `value` field stores the operator text ("UNION", "UNION ALL", "INTERSECT", etc.). The FLAG_SET_OP_ALL flag distinguishes ALL variants. + +--- + +### Task 3: Emitter Support for Compound Query Nodes + +**Files:** +- Modify: `include/sql_parser/emitter.h` + +- [ ] **Step 1: Add emit methods for compound nodes** + +Add cases to the `emit_node()` switch: +```cpp +case NodeType::NODE_COMPOUND_QUERY: emit_compound_query(node); break; +case NodeType::NODE_SET_OPERATION: emit_set_operation(node); break; +``` + +- [ ] **Step 2: Implement emit methods** + +```cpp +void emit_compound_query(const AstNode* node); +void emit_set_operation(const AstNode* node); +``` + +`emit_compound_query()`: emit the top-level set operation child, then trailing ORDER BY and LIMIT children. + +`emit_set_operation()`: emit left child, then operator text (from node value), then right child. The operator text includes the ALL modifier if FLAG_SET_OP_ALL is set. + +--- + +### Task 4: Integration into parse_select() + +**Files:** +- Modify: `src/sql_parser/parser.cpp` + +- [ ] **Step 1: Update `parse_select()` to use CompoundQueryParser** + +Replace the current `parse_select()` implementation: + +```cpp +template <Dialect D> +ParseResult Parser<D>::parse_select() { + ParseResult r; + r.stmt_type = StmtType::SELECT; + + CompoundQueryParser<D> compound_parser(tokenizer_, arena_); + AstNode* ast = compound_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} +``` + +- [ ] **Step 2: Add `#include "sql_parser/compound_query_parser.h"` to `parser.cpp`** + +- [ ] **Step 3: Handle parenthesized SELECT at classifier level** + +The classifier currently only dispatches on `TK_SELECT`. 
A query starting with `(SELECT ...` would not be recognized. Add handling for `TK_LPAREN` in the classifier: peek ahead; if next token is `SELECT` (or another `TK_LPAREN`), dispatch to `parse_select()`. The `CompoundQueryParser` will handle the parenthesized form. + +```cpp +case TokenType::TK_LPAREN: { + // Peek to see if this is a parenthesized SELECT / compound query + Token next = tokenizer_.peek(); + if (next.type == TokenType::TK_SELECT || next.type == TokenType::TK_LPAREN) { + // Put the LPAREN back by adjusting state, or handle in CompoundQueryParser + return parse_select_from_lparen(); + } + return extract_unknown(first); +} +``` + +Note: The exact mechanism for "putting back" the `(` depends on how the tokenizer works. The simplest approach is for `parse_select()` to handle the case where the first token was `(` instead of `SELECT` -- pass a flag or have `CompoundQueryParser` check for leading `(`. + +- [ ] **Step 4: Run all tests** + +```bash +make -f Makefile.new test +``` + +All existing SELECT tests must still pass (CompoundQueryParser returns bare NODE_SELECT_STMT for non-compound queries). New compound tests should pass. diff --git a/docs/superpowers/plans/2026-03-24-plan11-query-digest.md b/docs/superpowers/plans/2026-03-24-plan11-query-digest.md new file mode 100644 index 0000000..25d3ba1 --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-plan11-query-digest.md @@ -0,0 +1,577 @@ +# Query Digest / Normalization Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build a query digest module that normalizes SQL queries (literals to `?`, IN list collapsing, keyword uppercasing) and produces a 64-bit hash for query rules matching. Works for all statement types: AST-based for Tier 1 statements, token-level fallback for Tier 2/unknown statements. 
+ +**Architecture:** The digest system has two paths. For Tier 1 statements with a full AST, the existing `Emitter` is extended with an `EmitMode::DIGEST` flag that changes how literals, IN lists, and keywords are emitted. For Tier 2 statements (or parse failures), a token-level walker normalizes directly from the token stream. Both paths produce a normalized string and a 64-bit FNV-1a hash. The `Digest` class provides the public API, wrapping both paths behind `compute(ast)` and `compute(sql, len)` methods. + +**Tech Stack:** C++17, existing parser infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` + +--- + +## Scope + +This plan builds: +1. `DigestResult` struct with normalized string and 64-bit hash +2. `EmitMode::DIGEST` flag on `Emitter` with modified emit behavior +3. Token-level digest fallback for Tier 2 statements +4. FNV-1a hash computation (incremental) +5. `Digest` public API class +6. Comprehensive tests: same-query-different-literals, IN collapsing, cross-tier consistency + +**Closes:** #9 + +**Dependencies:** Benefits from Plans 7-10 being complete (more Tier 1 statements to test AST-based digest), but works independently via token-level fallback for any statement type. 
+ +--- + +## File Structure + +``` +include/sql_parser/ + digest.h — (create) Digest class, DigestResult, FNV-1a hash + emitter.h — (modify) add EmitMode enum, modify literal/in-list emission + +tests/ + test_digest.cpp — (create) digest tests + +Makefile.new — (modify) add test_digest.cpp to TEST_SRCS +``` + +--- + +### Task 1: DigestResult and FNV-1a Hash + +**Files:** +- Create: `include/sql_parser/digest.h` + +- [ ] **Step 1: Define DigestResult and hash function** + +Create `include/sql_parser/digest.h`: +```cpp +#ifndef SQL_PARSER_DIGEST_H +#define SQL_PARSER_DIGEST_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/emitter.h" +#include "sql_parser/string_builder.h" +#include <cstdint> + +namespace sql_parser { + +struct DigestResult { + StringRef normalized; // "SELECT * FROM t WHERE id = ?" + uint64_t hash; // 64-bit FNV-1a hash +}; + +// FNV-1a 64-bit hash — simple, fast, no external dependency +struct FnvHash { + static constexpr uint64_t FNV_OFFSET_BASIS = 14695981039346656037ULL; + static constexpr uint64_t FNV_PRIME = 1099511628211ULL; + + uint64_t state = FNV_OFFSET_BASIS; + + void update(const char* data, size_t len) { + for (size_t i = 0; i < len; ++i) { + state ^= static_cast<uint64_t>(static_cast<unsigned char>(data[i])); + state *= FNV_PRIME; + } + } + + void update_char(char c) { + state ^= static_cast<uint64_t>(static_cast<unsigned char>(c)); + state *= FNV_PRIME; + } + + uint64_t finish() const { return state; } +}; + +template <Dialect D> +class Digest { +public: + explicit Digest(Arena& arena) : arena_(arena) {} + + // From a parsed AST (Tier 1) — uses Emitter in DIGEST mode + DigestResult compute(const AstNode* ast); + + // From raw SQL (works for any statement) — uses token-level fallback + DigestResult compute(const char* sql, size_t len); + +private: + Arena& arena_; + + // Token-level digest: walk tokens, normalize, hash + DigestResult compute_token_level(const char* sql, size_t len); +}; + +} // namespace 
sql_parser + +#endif // SQL_PARSER_DIGEST_H +``` + +--- + +### Task 2: Emitter Digest Mode + +**Files:** +- Modify: `include/sql_parser/emitter.h` + +- [ ] **Step 1: Add EmitMode enum and constructor parameter** + +Add to `emitter.h`: +```cpp +enum class EmitMode : uint8_t { NORMAL, DIGEST }; +``` + +Update the `Emitter` constructor: +```cpp +explicit Emitter(Arena& arena, EmitMode mode = EmitMode::NORMAL, + const ParamBindings* bindings = nullptr) + : sb_(arena), bindings_(bindings), placeholder_index_(0), mode_(mode) {} +``` + +Add `EmitMode mode_` as a private member. + +- [ ] **Step 2: Modify emit_value for literals in DIGEST mode** + +In `emit_node()`, change literal handling: +```cpp +case NodeType::NODE_LITERAL_INT: +case NodeType::NODE_LITERAL_FLOAT: + if (mode_ == EmitMode::DIGEST) { sb_.append_char('?'); break; } + emit_value(node); break; + +case NodeType::NODE_LITERAL_STRING: + if (mode_ == EmitMode::DIGEST) { sb_.append_char('?'); break; } + emit_string_literal(node); break; +``` + +- [ ] **Step 3: Modify emit_in_list for DIGEST mode** + +In `emit_in_list()`, when `mode_ == EmitMode::DIGEST`: +```cpp +void emit_in_list(const AstNode* node) { + const AstNode* expr = node->first_child; + if (expr) emit_node(expr); + sb_.append(" IN ("); + if (mode_ == EmitMode::DIGEST) { + sb_.append_char('?'); + } else { + bool first = true; + for (const AstNode* val = expr ? expr->next_sibling : nullptr; val; val = val->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(val); + } + } + sb_.append_char(')'); +} +``` + +- [ ] **Step 4: Keyword uppercasing in DIGEST mode** + +All keywords are already emitted as literals (e.g., `sb_.append("SELECT ")`) which are uppercase. Identifiers from the input that happen to be keywords are emitted via `emit_value()` which preserves original case. For digest mode, the token-level path handles uppercasing. 
For AST-based digest, since the emitter already uses uppercase keyword strings in `emit_*` methods, this is mostly free. The only concern is node values that contain lowercase keywords (e.g., a `join` type stored as-is). For full correctness, `emit_value()` in DIGEST mode should uppercase keyword-type nodes, but this can be deferred as the emitter's structural methods already use uppercase constants. + +--- + +### Task 3: Token-Level Digest Implementation + +**Files:** +- Modify: `include/sql_parser/digest.h` + +- [ ] **Step 1: Implement token-level digest** + +Add the `compute_token_level()` implementation to `digest.h`: + +```cpp +template +DigestResult Digest::compute_token_level(const char* sql, size_t len) { + Tokenizer tok; + tok.reset(sql, len); + StringBuilder sb(arena_); + FnvHash hasher; + bool first = true; + bool in_list_context = false; + bool last_was_placeholder = false; + + while (true) { + Token t = tok.next_token(); + if (t.type == TokenType::TK_EOF) break; + if (t.type == TokenType::TK_SEMICOLON) break; + + // Determine what to emit for this token + // ... (see implementation notes below) + } + + StringRef normalized = sb.finish(); + hasher.update(normalized.ptr, normalized.len); + return DigestResult{normalized, hasher.finish()}; +} +``` + +Key logic for each token type: +- **TK_INTEGER, TK_FLOAT, TK_STRING**: emit `?`. If inside an IN list and the previous emitted token was also `?`, skip (IN list collapsing). +- **Keywords**: emit uppercase form. Use the token's text but uppercased. +- **TK_IDENTIFIER**: emit as-is. +- **TK_COMMA in IN context**: if collapsing, skip the comma too. +- **All other tokens**: emit as-is with single space separation. + +Track `in_list_context` by watching for `IN` followed by `(`. Reset when `)` is seen. 
+ +- [ ] **Step 2: Implement AST-based compute** + +```cpp +template +DigestResult Digest::compute(const AstNode* ast) { + Emitter emitter(arena_, EmitMode::DIGEST); + emitter.emit(ast); + StringRef normalized = emitter.result(); + FnvHash hasher; + hasher.update(normalized.ptr, normalized.len); + return DigestResult{normalized, hasher.finish()}; +} +``` + +- [ ] **Step 3: Implement raw SQL compute (with auto-detection)** + +```cpp +template +DigestResult Digest::compute(const char* sql, size_t len) { + // Always use token-level for the raw SQL path. + // Callers with a parsed AST should use compute(ast) directly. + return compute_token_level(sql, len); +} +``` + +Note: The `compute(sql, len)` overload is for convenience when no AST is available. Callers that have already parsed can use `compute(ast)` for potentially better normalization (AST-aware IN collapsing, correct VALUES row handling). + +--- + +### Task 4: Comprehensive Tests + +**Files:** +- Create: `tests/test_digest.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Write digest tests** + +Create `tests/test_digest.cpp`: +```cpp +#include +#include "sql_parser/parser.h" +#include "sql_parser/digest.h" + +using namespace sql_parser; + +class MySQLDigestTest : public ::testing::Test { +protected: + Parser parser; + + DigestResult digest_ast(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + Digest digest(parser.arena()); + if (r.ast) { + return digest.compute(r.ast); + } + // Fallback to token-level + return digest.compute(sql, strlen(sql)); + } + + DigestResult digest_token(const char* sql) { + Digest digest(parser.arena()); + return digest.compute(sql, strlen(sql)); + } + + std::string normalized(const char* sql) { + auto d = digest_ast(sql); + return std::string(d.normalized.ptr, d.normalized.len); + } + + std::string normalized_token(const char* sql) { + auto d = digest_token(sql); + return std::string(d.normalized.ptr, d.normalized.len); + } +}; + +// ========== Literal normalization 
========== + +TEST_F(MySQLDigestTest, IntegerLiteralNormalized) { + std::string out = normalized("SELECT * FROM t WHERE id = 42"); + EXPECT_EQ(out, "SELECT * FROM t WHERE id = ?"); +} + +TEST_F(MySQLDigestTest, FloatLiteralNormalized) { + std::string out = normalized("SELECT * FROM t WHERE price > 3.14"); + EXPECT_EQ(out, "SELECT * FROM t WHERE price > ?"); +} + +TEST_F(MySQLDigestTest, StringLiteralNormalized) { + std::string out = normalized("SELECT * FROM t WHERE name = 'Alice'"); + EXPECT_EQ(out, "SELECT * FROM t WHERE name = ?"); +} + +TEST_F(MySQLDigestTest, MultipleLiteralsNormalized) { + std::string out = normalized("SELECT * FROM t WHERE a = 1 AND b = 'x' AND c = 3.14"); + EXPECT_EQ(out, "SELECT * FROM t WHERE a = ? AND b = ? AND c = ?"); +} + +// ========== Same query, different literals => same hash ========== + +TEST_F(MySQLDigestTest, SameQueryDifferentInts) { + auto d1 = digest_ast("SELECT * FROM t WHERE id = 1"); + auto d2 = digest_ast("SELECT * FROM t WHERE id = 999"); + EXPECT_EQ(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, SameQueryDifferentStrings) { + auto d1 = digest_ast("SELECT * FROM t WHERE name = 'Alice'"); + auto d2 = digest_ast("SELECT * FROM t WHERE name = 'Bob'"); + EXPECT_EQ(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, DifferentQueriesDifferentHash) { + auto d1 = digest_ast("SELECT * FROM t WHERE id = 1"); + auto d2 = digest_ast("SELECT * FROM t WHERE name = 1"); + EXPECT_NE(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, DifferentTablesDifferentHash) { + auto d1 = digest_ast("SELECT * FROM users WHERE id = 1"); + auto d2 = digest_ast("SELECT * FROM orders WHERE id = 1"); + EXPECT_NE(d1.hash, d2.hash); +} + +// ========== IN list collapsing ========== + +TEST_F(MySQLDigestTest, InListCollapsed) { + std::string out = normalized("SELECT * FROM t WHERE id IN (1, 2, 3)"); + EXPECT_EQ(out, "SELECT * FROM t WHERE id IN (?)"); +} + +TEST_F(MySQLDigestTest, InListDifferentSizesSameHash) { + auto d1 = digest_ast("SELECT * FROM t WHERE 
id IN (1, 2, 3)"); + auto d2 = digest_ast("SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5)"); + EXPECT_EQ(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, InListSingleValueSameHash) { + auto d1 = digest_ast("SELECT * FROM t WHERE id IN (1)"); + auto d2 = digest_ast("SELECT * FROM t WHERE id IN (1, 2, 3)"); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== Keyword uppercasing ========== + +TEST_F(MySQLDigestTest, KeywordsUppercased) { + // Token-level digest should uppercase keywords + std::string out = normalized_token("select * from t where id = 1"); + EXPECT_EQ(out, "SELECT * FROM t WHERE id = ?"); +} + +// ========== Token-level fallback for Tier 2 ========== + +TEST_F(MySQLDigestTest, TokenLevelInsert) { + std::string out = normalized_token("INSERT INTO users (name) VALUES ('Alice')"); + EXPECT_EQ(out, "INSERT INTO users (name) VALUES (?)"); +} + +TEST_F(MySQLDigestTest, TokenLevelUpdate) { + std::string out = normalized_token("UPDATE users SET name = 'Bob' WHERE id = 42"); + EXPECT_EQ(out, "UPDATE users SET name = ? 
WHERE id = ?"); +} + +TEST_F(MySQLDigestTest, TokenLevelDelete) { + std::string out = normalized_token("DELETE FROM users WHERE id = 1"); + EXPECT_EQ(out, "DELETE FROM users WHERE id = ?"); +} + +TEST_F(MySQLDigestTest, TokenLevelCreateTable) { + // Tier 2 statement -- should still normalize literals + std::string out = normalized_token("CREATE TABLE t (id INT DEFAULT 0)"); + EXPECT_EQ(out, "CREATE TABLE t (id INT DEFAULT ?)"); +} + +TEST_F(MySQLDigestTest, TokenLevelInCollapsing) { + std::string out = normalized_token("SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5)"); + EXPECT_EQ(out, "SELECT * FROM t WHERE id IN (?)"); +} + +// ========== Consistency: AST-based and token-level produce same hash ========== + +TEST_F(MySQLDigestTest, AstAndTokenLevelConsistentForSelect) { + const char* sql = "SELECT * FROM users WHERE id = 42"; + auto d_ast = digest_ast(sql); + auto d_tok = digest_token(sql); + // The normalized strings should be the same + EXPECT_EQ( + std::string(d_ast.normalized.ptr, d_ast.normalized.len), + std::string(d_tok.normalized.ptr, d_tok.normalized.len) + ); + // Therefore hashes should match + EXPECT_EQ(d_ast.hash, d_tok.hash); +} + +// ========== SET statement digest ========== + +TEST_F(MySQLDigestTest, SetVariableDigest) { + auto d1 = digest_ast("SET autocommit = 1"); + auto d2 = digest_ast("SET autocommit = 0"); + // SET with different values should produce same digest + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== NULL and boolean literals ========== + +TEST_F(MySQLDigestTest, NullPreserved) { + // NULL is not a literal value — it's a keyword, should not be replaced with ? 
+ std::string out = normalized("SELECT * FROM t WHERE a IS NULL"); + EXPECT_EQ(out, "SELECT * FROM t WHERE a IS NULL"); +} + +TEST_F(MySQLDigestTest, LimitDigest) { + auto d1 = digest_ast("SELECT * FROM t LIMIT 10"); + auto d2 = digest_ast("SELECT * FROM t LIMIT 20"); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== Placeholder passthrough ========== + +TEST_F(MySQLDigestTest, PlaceholderPassthrough) { + std::string out = normalized_token("SELECT * FROM t WHERE id = ?"); + EXPECT_EQ(out, "SELECT * FROM t WHERE id = ?"); +} + +// ========== Bulk digest tests ========== + +struct DigestTestCase { + const char* sql1; + const char* sql2; + bool same_hash; + const char* description; +}; + +static const DigestTestCase digest_bulk_cases[] = { + {"SELECT * FROM t WHERE id = 1", "SELECT * FROM t WHERE id = 2", true, "different int literals"}, + {"SELECT * FROM t WHERE s = 'a'", "SELECT * FROM t WHERE s = 'b'", true, "different string literals"}, + {"SELECT * FROM t WHERE x = 1.5", "SELECT * FROM t WHERE x = 2.7", true, "different float literals"}, + {"SELECT * FROM t WHERE id IN (1,2)", "SELECT * FROM t WHERE id IN (1,2,3,4)", true, "in list sizes"}, + {"SELECT * FROM t LIMIT 10", "SELECT * FROM t LIMIT 100", true, "different limits"}, + {"SELECT * FROM t1 WHERE id = 1", "SELECT * FROM t2 WHERE id = 1", false, "different tables"}, + {"SELECT a FROM t WHERE id = 1", "SELECT b FROM t WHERE id = 1", false, "different columns"}, + {"SELECT * FROM t WHERE a = 1", "SELECT * FROM t WHERE b = 1", false, "different where cols"}, + {"SELECT * FROM t ORDER BY a", "SELECT * FROM t ORDER BY b", false, "different order"}, +}; + +TEST(MySQLDigestBulk, HashConsistency) { + Parser parser; + for (const auto& tc : digest_bulk_cases) { + auto r1 = parser.parse(tc.sql1, strlen(tc.sql1)); + Digest d1(parser.arena()); + auto dr1 = r1.ast ? 
d1.compute(r1.ast) : d1.compute(tc.sql1, strlen(tc.sql1)); + + auto r2 = parser.parse(tc.sql2, strlen(tc.sql2)); + Digest d2(parser.arena()); + auto dr2 = r2.ast ? d2.compute(r2.ast) : d2.compute(tc.sql2, strlen(tc.sql2)); + + if (tc.same_hash) { + EXPECT_EQ(dr1.hash, dr2.hash) + << "Expected same hash: " << tc.description + << "\n SQL1: " << tc.sql1 << "\n SQL2: " << tc.sql2 + << "\n Norm1: " << std::string(dr1.normalized.ptr, dr1.normalized.len) + << "\n Norm2: " << std::string(dr2.normalized.ptr, dr2.normalized.len); + } else { + EXPECT_NE(dr1.hash, dr2.hash) + << "Expected different hash: " << tc.description + << "\n SQL1: " << tc.sql1 << "\n SQL2: " << tc.sql2; + } + } +} + +// ========== PostgreSQL digest ========== + +class PgSQLDigestTest : public ::testing::Test { +protected: + Parser parser; + + DigestResult digest_token(const char* sql) { + Digest digest(parser.arena()); + return digest.compute(sql, strlen(sql)); + } + + std::string normalized_token(const char* sql) { + auto d = digest_token(sql); + return std::string(d.normalized.ptr, d.normalized.len); + } +}; + +TEST_F(PgSQLDigestTest, BasicDigest) { + std::string out = normalized_token("SELECT * FROM users WHERE id = 42"); + EXPECT_EQ(out, "SELECT * FROM users WHERE id = ?"); +} + +TEST_F(PgSQLDigestTest, DollarPlaceholderPreserved) { + std::string out = normalized_token("SELECT * FROM users WHERE id = $1"); + EXPECT_EQ(out, "SELECT * FROM users WHERE id = $1"); +} + +TEST_F(PgSQLDigestTest, InListCollapsed) { + std::string out = normalized_token("SELECT * FROM t WHERE id IN (1, 2, 3)"); + EXPECT_EQ(out, "SELECT * FROM t WHERE id IN (?)"); +} + +TEST_F(PgSQLDigestTest, ReturningDigest) { + std::string out = normalized_token("INSERT INTO t (a) VALUES (1) RETURNING *"); + EXPECT_EQ(out, "INSERT INTO t (a) VALUES (?) RETURNING *"); +} +``` + +- [ ] **Step 2: Add test_digest.cpp to Makefile.new** + +Add `$(TEST_DIR)/test_digest.cpp \` to the `TEST_SRCS` list. 
+ +- [ ] **Step 3: Run all tests** + +```bash +make -f Makefile.new test +``` + +--- + +### Task 5: Edge Cases and Robustness + +- [ ] **Step 1: Handle VALUES rows in digest mode** + +For INSERT ... VALUES, the AST-based digest should normalize all value expressions to `?` but preserve the row structure. Multi-row INSERTs should collapse to a single row in the digest: `INSERT INTO t (a,b) VALUES (?,?)` regardless of how many rows the original had. This is the ProxySQL convention. + +Implementation: In `emit_values_clause()`, when `mode_ == EmitMode::DIGEST`, emit only the first `NODE_VALUES_ROW` child and skip the rest. + +- [ ] **Step 2: Ensure whitespace normalization** + +Both AST-based and token-level digest must produce single-space-separated output with no leading/trailing whitespace. The existing `StringBuilder` and emitter already produce well-formed output, but verify in tests. + +- [ ] **Step 3: Verify hash stability** + +FNV-1a is deterministic. Add a test that computes a digest for a known query and asserts the exact hash value to catch accidental changes: + +```cpp +TEST_F(MySQLDigestTest, HashStability) { + auto d = digest_ast("SELECT * FROM users WHERE id = 1"); + // The normalized form is "SELECT * FROM users WHERE id = ?" 
+ // FNV-1a of that string should always be the same + EXPECT_NE(d.hash, 0ULL); + // Store the computed hash and assert it in future test runs + // (exact value depends on implementation, fill in after first run) +} +``` + +- [ ] **Step 4: Run all tests** + +```bash +make -f Makefile.new test +``` diff --git a/docs/superpowers/plans/2026-03-24-plan7-insert-parser.md b/docs/superpowers/plans/2026-03-24-plan7-insert-parser.md new file mode 100644 index 0000000..09fc15f --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-plan7-insert-parser.md @@ -0,0 +1,827 @@ +# INSERT/REPLACE Deep Parser Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Extract shared table reference parsing into `TableRefParser`, then build a full INSERT/REPLACE deep parser for MySQL and PostgreSQL with emitter support and comprehensive tests. + +**Architecture:** First, the table reference methods (`parse_from_clause`, `parse_table_reference`, `parse_join`, `parse_optional_alias`, `is_join_start`, `is_alias_start`) are extracted from `SelectParser` into a standalone `TableRefParser` utility class. Then, `InsertParser` is built as a header-only template following the same pattern as `SelectParser` and `SetParser`. It handles VALUES, SELECT, SET (MySQL), ON DUPLICATE KEY UPDATE (MySQL), ON CONFLICT (PostgreSQL), and RETURNING (PostgreSQL). The parser is integrated via `parser.cpp` by replacing the existing `extract_insert()` and `extract_replace()` Tier 2 extractors. + +**Tech Stack:** C++17, existing parser infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` + +--- + +## Scope + +This plan builds: +1. `TableRefParser` — shared FROM/JOIN/table ref parsing utility, extracted from SelectParser +2. 
`InsertParser` — full INSERT/REPLACE deep parser for both dialects +3. Emitter extensions for all INSERT-related node types +4. Classifier updates to route TK_INSERT and TK_REPLACE to the deep parser +5. Comprehensive tests including round-trip tests + +**Closes:** #5 + +**Dependencies:** None (this is the first plan in the Tier 1 promotions series). + +--- + +## File Structure + +``` +include/sql_parser/ + table_ref_parser.h — (create) shared FROM/JOIN/table reference parsing + insert_parser.h — (create) INSERT/REPLACE statement parser + select_parser.h — (modify) replace private methods with TableRefParser calls + emitter.h — (modify) add INSERT emit methods + common.h — (modify) add new NodeType values + token.h — (modify) add new TokenType values + +src/sql_parser/ + parser.cpp — (modify) replace extract_insert/extract_replace with parse_insert() + +include/sql_parser/ + parser.h — (modify) add parse_insert() declaration, remove extract_insert/extract_replace + +include/sql_parser/ + keywords_mysql.h — (modify) add new keywords + keywords_pgsql.h — (modify) add new keywords + +tests/ + test_insert.cpp — (create) INSERT parser tests + +Makefile.new — (modify) add test_insert.cpp to TEST_SRCS +``` + +--- + +### Task 1: Extract TableRefParser from SelectParser + +**Files:** +- Create: `include/sql_parser/table_ref_parser.h` +- Modify: `include/sql_parser/select_parser.h` + +This task extracts all table reference parsing logic into a shared utility class. SelectParser's private methods are replaced with calls to TableRefParser. All existing tests must continue to pass. 
+ +- [ ] **Step 1: Create `table_ref_parser.h` with extracted methods** + +Create `include/sql_parser/table_ref_parser.h`: +```cpp +#ifndef SQL_PARSER_TABLE_REF_PARSER_H +#define SQL_PARSER_TABLE_REF_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +template <Dialect D> +class TableRefParser { +public: + TableRefParser(Tokenizer& tokenizer, Arena& arena, + ExpressionParser& expr_parser) + : tok_(tokenizer), arena_(arena), expr_parser_(expr_parser) {} + + // Parse a FROM clause: table_ref [, table_ref | JOIN ...]* + AstNode* parse_from_clause(); + + // Parse a single table reference (simple name, qualified name, subquery) + AstNode* parse_table_reference(); + + // Parse a JOIN clause (modifiers + JOIN keyword already detected by caller) + AstNode* parse_join(AstNode* left_ref); + + // Parse optional alias (AS name or implicit alias) + void parse_optional_alias(AstNode* parent); + + // Check if a token can start a JOIN + static bool is_join_start(TokenType type); + + // Check if a token can start an implicit alias + static bool is_alias_start(TokenType type); + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser& expr_parser_; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_TABLE_REF_PARSER_H +``` + +The implementation follows the exact same logic currently in `SelectParser`. Move the method bodies from `select_parser.h` into this class. + +- [ ] **Step 2: Update SelectParser to use TableRefParser** + +Modify `include/sql_parser/select_parser.h`: +- Add `#include "sql_parser/table_ref_parser.h"` +- Add a `TableRefParser<D> table_ref_parser_` member, initialized in the constructor +- Replace all calls to the private `parse_from_clause()`, `parse_table_reference()`, `parse_join()`, `parse_optional_alias()` with calls to `table_ref_parser_.parse_from_clause()`, etc. 
+- Remove the private method implementations that were moved +- Keep `is_alias_start()` calls in `parse_select_item()` pointing to `TableRefParser::is_alias_start()` + +- [ ] **Step 3: Run all existing tests — they must pass unchanged** + +```bash +make -f Makefile.new test +``` + +No test changes should be needed. This is a pure refactoring step. + +--- + +### Task 2: Add New Tokens and Node Types + +**Files:** +- Modify: `include/sql_parser/token.h` +- Modify: `include/sql_parser/common.h` +- Modify: `include/sql_parser/keywords_mysql.h` +- Modify: `include/sql_parser/keywords_pgsql.h` + +- [ ] **Step 1: Add new token types to `token.h`** + +Add after the existing `TK_SQL_CALC_FOUND_ROWS` line: +```cpp +TK_DELAYED, +TK_HIGH_PRIORITY, +TK_DUPLICATE, +TK_KEY, +TK_CONFLICT, +TK_DO, +TK_NOTHING, +TK_RETURNING, +TK_CONSTRAINT, +``` + +- [ ] **Step 2: Add new node types to `common.h`** + +Add after the existing `NODE_CASE_WHEN` entry: +```cpp +// INSERT nodes +NODE_INSERT_STMT, +NODE_INSERT_COLUMNS, // (col1, col2, ...) +NODE_VALUES_CLAUSE, // VALUES keyword wrapper +NODE_VALUES_ROW, // single (val1, val2, ...) row +NODE_INSERT_SET_CLAUSE, // MySQL INSERT ... SET col=val form +NODE_ON_DUPLICATE_KEY, // MySQL ON DUPLICATE KEY UPDATE +NODE_ON_CONFLICT, // PostgreSQL ON CONFLICT +NODE_CONFLICT_TARGET, // PostgreSQL conflict target (cols or ON CONSTRAINT) +NODE_CONFLICT_ACTION, // DO UPDATE SET ... or DO NOTHING +NODE_RETURNING_CLAUSE, // PostgreSQL RETURNING expr_list + +// Shared +NODE_STMT_OPTIONS, // LOW_PRIORITY, IGNORE, QUICK, DELAYED, etc. +NODE_UPDATE_SET_ITEM, // single col=expr pair (shared by INSERT SET and UPDATE SET) +``` + +- [ ] **Step 3: Register keywords in keyword tables** + +Add to `keywords_mysql.h`: DELAYED, HIGH_PRIORITY, DUPLICATE, KEY, CONFLICT, DO, NOTHING, RETURNING, CONSTRAINT. + +Add to `keywords_pgsql.h`: CONFLICT, DO, NOTHING, RETURNING, CONSTRAINT. 
+ +- [ ] **Step 4: Update `is_alias_start()` blocklist in `TableRefParser`** + +Add to the blocklist: `TK_RETURNING`, `TK_CONFLICT`, `TK_DO`, `TK_NOTHING`, `TK_DUPLICATE`. + +--- + +### Task 3: INSERT Parser — Core Structure, VALUES, Column List + +**Files:** +- Create: `include/sql_parser/insert_parser.h` +- Create: `tests/test_insert.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Write tests for basic INSERT parsing** + +Create `tests/test_insert.cpp`: +```cpp +#include <gtest/gtest.h> +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLInsertTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Basic INSERT ========== + +TEST_F(MySQLInsertTest, SimpleInsert) { + auto r = parser.parse("INSERT INTO users (id, name) VALUES (1, 'Alice')", 48); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_INSERT_STMT); +} + +TEST_F(MySQLInsertTest, InsertWithoutInto) { + auto r = parser.parse("INSERT users (id) VALUES (1)", 28); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertWithoutColumnList) { + auto r = parser.parse("INSERT INTO users VALUES (1, 'Alice')", 37); + EXPECT_EQ(r.status, 
ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* cols = find_child(r.ast, NodeType::NODE_INSERT_COLUMNS); + EXPECT_EQ(cols, nullptr); // no column list +} + +TEST_F(MySQLInsertTest, InsertColumnList) { + auto r = parser.parse("INSERT INTO users (id, name, email) VALUES (1, 'Alice', 'a@b.com')", 66); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* cols = find_child(r.ast, NodeType::NODE_INSERT_COLUMNS); + ASSERT_NE(cols, nullptr); + EXPECT_EQ(child_count(cols), 3); +} + +TEST_F(MySQLInsertTest, InsertMultiRow) { + auto r = parser.parse("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')", 60); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* values = find_child(r.ast, NodeType::NODE_VALUES_CLAUSE); + ASSERT_NE(values, nullptr); + EXPECT_EQ(child_count(values), 2); // two rows +} + +TEST_F(MySQLInsertTest, InsertTableRef) { + auto r = parser.parse("INSERT INTO mydb.users (id) VALUES (1)", 38); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* tref = find_child(r.ast, NodeType::NODE_TABLE_REF); + ASSERT_NE(tref, nullptr); +} + +// ========== MySQL Options ========== + +TEST_F(MySQLInsertTest, InsertLowPriority) { + auto r = parser.parse("INSERT LOW_PRIORITY INTO users (id) VALUES (1)", 46); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_STMT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLInsertTest, InsertDelayed) { + auto r = parser.parse("INSERT DELAYED INTO users (id) VALUES (1)", 41); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertHighPriority) { + auto r = parser.parse("INSERT HIGH_PRIORITY INTO users (id) VALUES (1)", 47); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertIgnore) { + auto r = parser.parse("INSERT IGNORE INTO users (id) VALUES (1)", 40); + EXPECT_EQ(r.status, 
ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertLowPriorityIgnore) { + auto r = parser.parse("INSERT LOW_PRIORITY IGNORE INTO users (id) VALUES (1)", 53); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== INSERT ... SELECT ========== + +TEST_F(MySQLInsertTest, InsertSelect) { + auto r = parser.parse("INSERT INTO users (id, name) SELECT id, name FROM temp_users", 60); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* select = find_child(r.ast, NodeType::NODE_SELECT_STMT); + ASSERT_NE(select, nullptr); +} + +TEST_F(MySQLInsertTest, InsertSelectWithWhere) { + const char* sql = "INSERT INTO users (id, name) SELECT id, name FROM temp WHERE active = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL INSERT ... SET ========== + +TEST_F(MySQLInsertTest, InsertSet) { + auto r = parser.parse("INSERT INTO users SET id = 1, name = 'Alice'", 44); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* set_clause = find_child(r.ast, NodeType::NODE_INSERT_SET_CLAUSE); + ASSERT_NE(set_clause, nullptr); + EXPECT_EQ(child_count(set_clause), 2); // two col=val pairs +} + +// ========== ON DUPLICATE KEY UPDATE ========== + +TEST_F(MySQLInsertTest, OnDuplicateKey) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') ON DUPLICATE KEY UPDATE name = 'Alice2'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* odku = find_child(r.ast, NodeType::NODE_ON_DUPLICATE_KEY); + ASSERT_NE(odku, nullptr); +} + +TEST_F(MySQLInsertTest, OnDuplicateKeyMultiple) { + const char* sql = "INSERT INTO users (id, name, email) VALUES (1, 'Alice', 'a@b.com') " + "ON DUPLICATE KEY UPDATE name = VALUES(name), email = VALUES(email)"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + 
ASSERT_NE(r.ast, nullptr); + auto* odku = find_child(r.ast, NodeType::NODE_ON_DUPLICATE_KEY); + ASSERT_NE(odku, nullptr); + EXPECT_EQ(child_count(odku), 2); +} + +// ========== REPLACE ========== + +TEST_F(MySQLInsertTest, ReplaceSimple) { + auto r = parser.parse("REPLACE INTO users (id, name) VALUES (1, 'Alice')", 49); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REPLACE); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_INSERT_STMT); + // REPLACE flag should be set in flags + EXPECT_NE(r.ast->flags & 0x01, 0); // FLAG_REPLACE = 0x01 +} + +TEST_F(MySQLInsertTest, ReplaceLowPriority) { + auto r = parser.parse("REPLACE LOW_PRIORITY INTO users (id) VALUES (1)", 47); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REPLACE); +} + +TEST_F(MySQLInsertTest, ReplaceDelayed) { + auto r = parser.parse("REPLACE DELAYED INTO users (id) VALUES (1)", 42); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== PostgreSQL INSERT ========== + +class PgSQLInsertTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLInsertTest, SimpleInsert) { + auto r = parser.parse("INSERT INTO users (id, name) VALUES (1, 'Alice')", 48); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, 
DefaultValues) { + auto r = parser.parse("INSERT INTO users DEFAULT VALUES", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictDoNothing) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT DO NOTHING"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* oc = find_child(r.ast, NodeType::NODE_ON_CONFLICT); + ASSERT_NE(oc, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictDoUpdate) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT (id) DO UPDATE SET name = 'Alice2'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* oc = find_child(r.ast, NodeType::NODE_ON_CONFLICT); + ASSERT_NE(oc, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictOnConstraint) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT ON CONSTRAINT users_pkey DO NOTHING"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictDoUpdateWhere) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name WHERE users.active = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, Returning) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') RETURNING id, name"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); + EXPECT_EQ(child_count(ret), 2); +} + +TEST_F(PgSQLInsertTest, ReturningStar) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') RETURNING *"; + auto r = 
parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictWithReturning) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ON_CONFLICT), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE), nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct InsertTestCase { + const char* sql; + const char* description; +}; + +static const InsertTestCase mysql_insert_bulk_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple single column"}, + {"INSERT INTO t (a, b) VALUES (1, 2)", "two columns"}, + {"INSERT INTO t (a, b, c) VALUES (1, 2, 3)", "three columns"}, + {"INSERT INTO t VALUES (1, 2)", "no column list"}, + {"INSERT t (a) VALUES (1)", "without INTO"}, + {"INSERT INTO db.t (a) VALUES (1)", "qualified table"}, + {"INSERT INTO t (a) VALUES (1), (2), (3)", "multi-row"}, + {"INSERT INTO t (a, b) VALUES (1, 'x'), (2, 'y')", "multi-row with strings"}, + {"INSERT LOW_PRIORITY INTO t (a) VALUES (1)", "low priority"}, + {"INSERT DELAYED INTO t (a) VALUES (1)", "delayed"}, + {"INSERT HIGH_PRIORITY INTO t (a) VALUES (1)", "high priority"}, + {"INSERT IGNORE INTO t (a) VALUES (1)", "ignore"}, + {"INSERT LOW_PRIORITY IGNORE INTO t (a) VALUES (1)", "low priority ignore"}, + {"INSERT INTO t SET a = 1", "set form single"}, + {"INSERT INTO t SET a = 1, b = 'x'", "set form multiple"}, + {"INSERT INTO t (a) SELECT a FROM t2", "insert select"}, + {"INSERT INTO t (a, b) SELECT a, b FROM t2 WHERE c > 0", "insert select with where"}, + {"INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 2", "on duplicate key"}, + {"INSERT INTO t (a, b) VALUES (1, 'x') ON DUPLICATE KEY UPDATE b = VALUES(b)", "odku values()"}, + {"INSERT 
INTO t (a, b) VALUES (1, 'x') ON DUPLICATE KEY UPDATE a = a + 1, b = 'y'", "odku multi"}, + {"REPLACE INTO t (a) VALUES (1)", "replace simple"}, + {"REPLACE INTO t (a, b) VALUES (1, 2)", "replace two cols"}, + {"REPLACE LOW_PRIORITY INTO t (a) VALUES (1)", "replace low priority"}, + {"REPLACE INTO t SET a = 1", "replace set form"}, +}; + +TEST(MySQLInsertBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_insert_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +static const InsertTestCase pgsql_insert_bulk_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple"}, + {"INSERT INTO t (a, b) VALUES (1, 2)", "two columns"}, + {"INSERT INTO t VALUES (1, 2)", "no column list"}, + {"INSERT INTO t DEFAULT VALUES", "default values"}, + {"INSERT INTO t (a) VALUES (1), (2)", "multi-row"}, + {"INSERT INTO t (a) SELECT a FROM t2", "insert select"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING", "on conflict do nothing"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO NOTHING", "on conflict col do nothing"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 2", "on conflict do update"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT ON CONSTRAINT t_pkey DO NOTHING", "on constraint"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = EXCLUDED.a", "excluded ref"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 2 WHERE t.b > 0", "do update where"}, + {"INSERT INTO t (a) VALUES (1) RETURNING a", "returning single"}, + {"INSERT INTO t (a) VALUES (1) RETURNING *", "returning star"}, + {"INSERT INTO t (a) VALUES (1) RETURNING a, b", "returning multi"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING RETURNING *", "conflict + returning"}, +}; + +TEST(PgSQLInsertBulk, 
AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_insert_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const InsertTestCase mysql_insert_roundtrip_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple"}, + {"INSERT INTO t (a, b) VALUES (1, 'x')", "two cols with string"}, + {"INSERT INTO t (a) VALUES (1), (2), (3)", "multi-row"}, + {"INSERT INTO t SET a = 1, b = 'x'", "set form"}, + {"INSERT LOW_PRIORITY IGNORE INTO t (a) VALUES (1)", "options"}, + {"INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 2", "odku"}, + {"REPLACE INTO t (a, b) VALUES (1, 2)", "replace"}, +}; + +TEST(MySQLInsertRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : mysql_insert_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +static const InsertTestCase pgsql_insert_roundtrip_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple"}, + {"INSERT INTO t DEFAULT VALUES", "default values"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING", "on conflict do nothing"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 2", "on conflict do update"}, + {"INSERT INTO t (a) VALUES (1) RETURNING *", "returning star"}, +}; + +TEST(PgSQLInsertRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : pgsql_insert_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << 
"Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} +``` + +- [ ] **Step 2: Add test_insert.cpp to Makefile.new** + +Add `$(TEST_DIR)/test_insert.cpp \` to the `TEST_SRCS` list. + +- [ ] **Step 3: Implement InsertParser class declaration** + +Create `include/sql_parser/insert_parser.h`: +```cpp +#ifndef SQL_PARSER_INSERT_PARSER_H +#define SQL_PARSER_INSERT_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" +#include "sql_parser/select_parser.h" + +namespace sql_parser { + +// Flag on NODE_INSERT_STMT to indicate REPLACE +static constexpr uint16_t FLAG_REPLACE = 0x01; + +template <Dialect D> +class InsertParser { +public: + InsertParser(Tokenizer& tokenizer, Arena& arena, bool is_replace = false) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena), + table_ref_parser_(tokenizer, arena, expr_parser_), + is_replace_(is_replace) {} + + // Parse INSERT/REPLACE statement (INSERT/REPLACE keyword already consumed). + AstNode* parse(); + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser<D> table_ref_parser_; + bool is_replace_; + + // Parse MySQL options: LOW_PRIORITY, DELAYED, HIGH_PRIORITY, IGNORE + AstNode* parse_stmt_options(); + + // Parse column list: (col1, col2, ...) + AstNode* parse_column_list(); + + // Parse VALUES clause: VALUES (row1), (row2), ... + AstNode* parse_values_clause(); + + // Parse a single values row: (expr, expr, ...) + AstNode* parse_values_row(); + + // Parse MySQL SET form: SET col=val, col=val, ... 
+ AstNode* parse_insert_set_clause(); + + // Parse a single col=expr pair + AstNode* parse_set_item(); + + // Parse MySQL ON DUPLICATE KEY UPDATE + AstNode* parse_on_duplicate_key(); + + // Parse PostgreSQL ON CONFLICT + AstNode* parse_on_conflict(); + + // Parse PostgreSQL RETURNING + AstNode* parse_returning(); +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_INSERT_PARSER_H +``` + +- [ ] **Step 4: Implement InsertParser parse methods** + +Implement all methods in the header following the spec's syntax. Key logic: +- `parse()`: options -> table ref -> column list -> (VALUES | SELECT | SET | DEFAULT VALUES) -> (ON DUPLICATE KEY UPDATE | ON CONFLICT) -> RETURNING +- Use `if constexpr (D == Dialect::MySQL)` for MySQL-only features (SET form, ON DUPLICATE KEY, DELAYED/HIGH_PRIORITY) +- Use `if constexpr (D == Dialect::PostgreSQL)` for PostgreSQL-only features (ON CONFLICT, RETURNING, DEFAULT VALUES) +- For INSERT ... SELECT, instantiate `SelectParser` and call its `parse()` method +- Refer to `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` for full syntax details + +--- + +### Task 4: Emitter Support for INSERT Nodes + +**Files:** +- Modify: `include/sql_parser/emitter.h` + +- [ ] **Step 1: Add emit methods for all INSERT node types** + +Add cases to the `emit_node()` switch: +```cpp +case NodeType::NODE_INSERT_STMT: emit_insert_stmt(node); break; +case NodeType::NODE_INSERT_COLUMNS: emit_insert_columns(node); break; +case NodeType::NODE_VALUES_CLAUSE: emit_values_clause(node); break; +case NodeType::NODE_VALUES_ROW: emit_values_row(node); break; +case NodeType::NODE_INSERT_SET_CLAUSE: emit_insert_set_clause(node); break; +case NodeType::NODE_ON_DUPLICATE_KEY: emit_on_duplicate_key(node); break; +case NodeType::NODE_ON_CONFLICT: emit_on_conflict(node); break; +case NodeType::NODE_CONFLICT_TARGET: emit_conflict_target(node); break; +case NodeType::NODE_CONFLICT_ACTION: emit_conflict_action(node); break; +case 
NodeType::NODE_RETURNING_CLAUSE: emit_returning(node); break; +case NodeType::NODE_STMT_OPTIONS: emit_stmt_options(node); break; +case NodeType::NODE_UPDATE_SET_ITEM: emit_update_set_item(node); break; +``` + +- [ ] **Step 2: Implement each emit method** + +Key emit methods (signatures only — implementer fills in bodies following existing emitter patterns): + +```cpp +void emit_insert_stmt(const AstNode* node); // INSERT/REPLACE INTO table ... +void emit_insert_columns(const AstNode* node); // (col1, col2, ...) +void emit_values_clause(const AstNode* node); // VALUES (row), (row) +void emit_values_row(const AstNode* node); // (val, val, ...) +void emit_insert_set_clause(const AstNode* node); // SET col=val, col=val +void emit_on_duplicate_key(const AstNode* node); // ON DUPLICATE KEY UPDATE col=val +void emit_on_conflict(const AstNode* node); // ON CONFLICT ... DO ... +void emit_conflict_target(const AstNode* node); // (cols) or ON CONSTRAINT name +void emit_conflict_action(const AstNode* node); // DO UPDATE SET ... or DO NOTHING +void emit_returning(const AstNode* node); // RETURNING expr, expr +void emit_stmt_options(const AstNode* node); // LOW_PRIORITY IGNORE etc. +void emit_update_set_item(const AstNode* node); // col = expr +``` + +`emit_insert_stmt()` must check the FLAG_REPLACE flag to emit "REPLACE" vs "INSERT". 
+ +--- + +### Task 5: Classifier Integration + +**Files:** +- Modify: `include/sql_parser/parser.h` +- Modify: `src/sql_parser/parser.cpp` + +- [ ] **Step 1: Add `parse_insert()` method declaration to `parser.h`** + +Add to the private section: +```cpp +ParseResult parse_insert(bool is_replace = false); +``` + +Remove (or keep as fallback, but they will no longer be called): +```cpp +// These are superseded by parse_insert(): +// ParseResult extract_insert(const Token& first); +// ParseResult extract_replace(const Token& first); +``` + +- [ ] **Step 2: Implement `parse_insert()` in `parser.cpp`** + +```cpp +template <Dialect D> +ParseResult Parser<D>::parse_insert(bool is_replace) { + ParseResult r; + r.stmt_type = is_replace ? StmtType::REPLACE : StmtType::INSERT; + + InsertParser<D> insert_parser(tokenizer_, arena_, is_replace); + AstNode* ast = insert_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} +``` + +- [ ] **Step 3: Update `classify_and_dispatch()` switch** + +Replace: +```cpp +case TokenType::TK_INSERT: return extract_insert(first); +// ... +case TokenType::TK_REPLACE: return extract_replace(first); +``` + +With: +```cpp +case TokenType::TK_INSERT: return parse_insert(false); +case TokenType::TK_REPLACE: return parse_insert(true); +``` + +- [ ] **Step 4: Add `#include "sql_parser/insert_parser.h"` to `parser.cpp`** + +- [ ] **Step 5: Run all tests** + +```bash +make -f Makefile.new test +``` + +All existing tests plus new INSERT tests should pass. 
diff --git a/docs/superpowers/plans/2026-03-24-plan8-update-parser.md b/docs/superpowers/plans/2026-03-24-plan8-update-parser.md new file mode 100644 index 0000000..29b55de --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-plan8-update-parser.md @@ -0,0 +1,587 @@ +# UPDATE Deep Parser Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build a full UPDATE deep parser for MySQL and PostgreSQL that handles multi-table JOINs (MySQL), FROM clause (PostgreSQL), SET assignments, ORDER BY/LIMIT (MySQL), and RETURNING (PostgreSQL). + +**Architecture:** `UpdateParser` is a header-only template class following the established pattern. It uses `TableRefParser` (from Plan 7) for table reference parsing, `ExpressionParser` for all expression positions, and reuses `NODE_UPDATE_SET_ITEM` (also from Plan 7) for SET assignments. The parser is integrated via `parser.cpp` by replacing the existing `extract_update()` Tier 2 extractor. + +**Tech Stack:** C++17, existing parser infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` + +--- + +## Scope + +This plan builds: +1. `UpdateParser` — full UPDATE deep parser for both dialects +2. Emitter extensions for UPDATE-specific node types +3. Classifier update to route TK_UPDATE to the deep parser +4. Comprehensive tests including round-trip tests + +**Closes:** #6 + +**Dependencies:** Plan 7 (provides `TableRefParser`, `NODE_STMT_OPTIONS`, `NODE_UPDATE_SET_ITEM`, `NODE_RETURNING_CLAUSE`, and related tokens/emitter infrastructure). 
+ +--- + +## File Structure + +``` +include/sql_parser/ + update_parser.h — (create) UPDATE statement parser (header-only template) + emitter.h — (modify) add UPDATE emit methods + common.h — (modify) add NODE_UPDATE_STMT, NODE_UPDATE_SET_CLAUSE + parser.h — (modify) add parse_update() declaration, remove extract_update + +src/sql_parser/ + parser.cpp — (modify) replace extract_update with parse_update() + +tests/ + test_update.cpp — (create) UPDATE parser tests + +Makefile.new — (modify) add test_update.cpp to TEST_SRCS +``` + +--- + +### Task 1: Add UPDATE Node Types + +**Files:** +- Modify: `include/sql_parser/common.h` + +- [ ] **Step 1: Add new node types** + +Add to `NodeType` enum (if not already added by Plan 7): +```cpp +// UPDATE nodes +NODE_UPDATE_STMT, +NODE_UPDATE_SET_CLAUSE, // SET col=expr, col=expr in UPDATE context +``` + +Note: `NODE_UPDATE_SET_ITEM`, `NODE_STMT_OPTIONS`, and `NODE_RETURNING_CLAUSE` are already added by Plan 7. + +--- + +### Task 2: UPDATE Parser Implementation + +**Files:** +- Create: `include/sql_parser/update_parser.h` +- Create: `tests/test_update.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Write tests for UPDATE parsing** + +Create `tests/test_update.cpp`: +```cpp +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLUpdateTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return 
std::string(result.ptr, result.len); + } +}; + +// ========== Basic UPDATE ========== + +TEST_F(MySQLUpdateTest, SimpleUpdate) { + auto r = parser.parse("UPDATE users SET name = 'Alice' WHERE id = 1", 44); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::UPDATE); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_UPDATE_STMT); +} + +TEST_F(MySQLUpdateTest, UpdateMultipleColumns) { + const char* sql = "UPDATE users SET name = 'Alice', email = 'a@b.com' WHERE id = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* set_clause = find_child(r.ast, NodeType::NODE_UPDATE_SET_CLAUSE); + ASSERT_NE(set_clause, nullptr); + EXPECT_EQ(child_count(set_clause), 2); +} + +TEST_F(MySQLUpdateTest, UpdateNoWhere) { + auto r = parser.parse("UPDATE users SET active = 0", 27); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* where = find_child(r.ast, NodeType::NODE_WHERE_CLAUSE); + EXPECT_EQ(where, nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateQualifiedTable) { + auto r = parser.parse("UPDATE mydb.users SET name = 'x' WHERE id = 1", 45); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL Options ========== + +TEST_F(MySQLUpdateTest, UpdateLowPriority) { + auto r = parser.parse("UPDATE LOW_PRIORITY users SET name = 'x' WHERE id = 1", 53); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_STMT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateIgnore) { + auto r = parser.parse("UPDATE IGNORE users SET name = 'x' WHERE id = 1", 47); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateLowPriorityIgnore) { + auto r = parser.parse("UPDATE LOW_PRIORITY IGNORE users SET name = 'x' WHERE id = 1", 60); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); 
+} + +// ========== MySQL ORDER BY + LIMIT ========== + +TEST_F(MySQLUpdateTest, UpdateOrderByLimit) { + const char* sql = "UPDATE users SET rank = rank + 1 WHERE active = 1 ORDER BY score DESC LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateLimit) { + auto r = parser.parse("UPDATE users SET active = 0 LIMIT 100", 37); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +// ========== MySQL Multi-Table UPDATE ========== + +TEST_F(MySQLUpdateTest, MultiTableJoin) { + const char* sql = "UPDATE users u JOIN orders o ON u.id = o.user_id " + "SET u.total = u.total + o.amount WHERE o.status = 'shipped'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLUpdateTest, MultiTableCommaJoin) { + const char* sql = "UPDATE users, orders SET users.total = orders.amount " + "WHERE users.id = orders.user_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLUpdateTest, MultiTableLeftJoin) { + const char* sql = "UPDATE users u LEFT JOIN orders o ON u.id = o.user_id " + "SET u.has_orders = 0 WHERE o.id IS NULL"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== PostgreSQL UPDATE ========== + +class PgSQLUpdateTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = 
node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLUpdateTest, SimpleUpdate) { + auto r = parser.parse("UPDATE users SET name = 'Alice' WHERE id = 1", 44); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::UPDATE); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateFrom) { + const char* sql = "UPDATE users SET total = orders.amount FROM orders WHERE users.id = orders.user_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateFromMultipleTables) { + const char* sql = "UPDATE users SET total = o.amount " + "FROM orders o, payments p " + "WHERE users.id = o.user_id AND o.id = p.order_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateReturning) { + const char* sql = "UPDATE users SET name = 'Alice' WHERE id = 1 RETURNING id, name"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); + EXPECT_EQ(child_count(ret), 2); +} + +TEST_F(PgSQLUpdateTest, UpdateReturningStar) { + const char* sql = "UPDATE users SET name = 'Alice' WHERE id = 1 RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateFromReturning) { + const char* sql = "UPDATE users SET total = 
o.amount FROM orders o " + "WHERE users.id = o.user_id RETURNING users.id, users.total"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_FROM_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE), nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateWithAlias) { + const char* sql = "UPDATE users AS u SET name = 'Alice' WHERE u.id = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct UpdateTestCase { + const char* sql; + const char* description; +}; + +static const UpdateTestCase mysql_update_bulk_cases[] = { + {"UPDATE t SET a = 1", "simple no where"}, + {"UPDATE t SET a = 1 WHERE b = 2", "simple with where"}, + {"UPDATE t SET a = 1, b = 2 WHERE c = 3", "multi column"}, + {"UPDATE t SET a = a + 1 WHERE b > 0", "expression value"}, + {"UPDATE t SET a = 'hello' WHERE b = 1", "string value"}, + {"UPDATE db.t SET a = 1", "qualified table"}, + {"UPDATE LOW_PRIORITY t SET a = 1", "low priority"}, + {"UPDATE IGNORE t SET a = 1", "ignore"}, + {"UPDATE LOW_PRIORITY IGNORE t SET a = 1", "low priority ignore"}, + {"UPDATE t SET a = 1 ORDER BY b LIMIT 10", "order by limit"}, + {"UPDATE t SET a = 1 LIMIT 100", "limit only"}, + {"UPDATE t1 JOIN t2 ON t1.id = t2.fk SET t1.a = t2.b", "join update"}, + {"UPDATE t1, t2 SET t1.a = t2.b WHERE t1.id = t2.fk", "comma join update"}, + {"UPDATE t1 LEFT JOIN t2 ON t1.id = t2.fk SET t1.a = 0 WHERE t2.id IS NULL", "left join"}, + {"UPDATE t SET a = NOW()", "function in value"}, + {"UPDATE t SET a = NULL WHERE b = 1", "set null"}, + {"UPDATE t SET a = CASE WHEN b > 0 THEN 1 ELSE 0 END", "case expression"}, +}; + +TEST(MySQLUpdateBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_update_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + 
EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::UPDATE) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +static const UpdateTestCase pgsql_update_bulk_cases[] = { + {"UPDATE t SET a = 1", "simple no where"}, + {"UPDATE t SET a = 1 WHERE b = 2", "simple with where"}, + {"UPDATE t SET a = 1, b = 2 WHERE c = 3", "multi column"}, + {"UPDATE t AS x SET a = 1 WHERE x.b = 2", "alias"}, + {"UPDATE t SET a = 1 FROM t2 WHERE t.id = t2.fk", "from clause"}, + {"UPDATE t SET a = t2.b FROM t2, t3 WHERE t.id = t2.fk AND t2.id = t3.fk", "from multi"}, + {"UPDATE t SET a = 1 WHERE b = 2 RETURNING *", "returning star"}, + {"UPDATE t SET a = 1 WHERE b = 2 RETURNING a, b", "returning cols"}, + {"UPDATE t SET a = 1 FROM t2 WHERE t.id = t2.fk RETURNING t.a", "from + returning"}, +}; + +TEST(PgSQLUpdateBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_update_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::UPDATE) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const UpdateTestCase mysql_update_roundtrip_cases[] = { + {"UPDATE t SET a = 1 WHERE b = 2", "simple"}, + {"UPDATE t SET a = 1, b = 'x' WHERE c = 3", "multi col"}, + {"UPDATE LOW_PRIORITY IGNORE t SET a = 1", "options"}, + {"UPDATE t SET a = 1 ORDER BY b DESC LIMIT 10", "order by limit"}, +}; + +TEST(MySQLUpdateRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : mysql_update_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + 
Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +static const UpdateTestCase pgsql_update_roundtrip_cases[] = { + {"UPDATE t SET a = 1 WHERE b = 2", "simple"}, + {"UPDATE t SET a = 1 FROM t2 WHERE t.id = t2.fk", "from clause"}, + {"UPDATE t SET a = 1 WHERE b = 2 RETURNING *", "returning"}, +}; + +TEST(PgSQLUpdateRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : pgsql_update_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} +``` + +- [ ] **Step 2: Add test_update.cpp to Makefile.new** + +Add `$(TEST_DIR)/test_update.cpp \` to the `TEST_SRCS` list. + +- [ ] **Step 3: Implement UpdateParser class** + +Create `include/sql_parser/update_parser.h`: +```cpp +#ifndef SQL_PARSER_UPDATE_PARSER_H +#define SQL_PARSER_UPDATE_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" + +namespace sql_parser { + +template +class UpdateParser { +public: + UpdateParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena), + table_ref_parser_(tokenizer, arena, expr_parser_) {} + + // Parse UPDATE statement (UPDATE keyword already consumed). 
+ AstNode* parse(); + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser table_ref_parser_; + + // Parse MySQL options: LOW_PRIORITY, IGNORE + AstNode* parse_stmt_options(); + + // Parse SET clause: SET col=expr [, col=expr ...] + AstNode* parse_update_set_clause(); + + // Parse a single col=expr pair + AstNode* parse_set_item(); + + // Parse WHERE clause + AstNode* parse_where_clause(); + + // Parse ORDER BY clause (MySQL single-table only) + AstNode* parse_order_by(); + + // Parse LIMIT clause (MySQL single-table only) + AstNode* parse_limit(); + + // Parse PostgreSQL FROM clause (after SET, before WHERE) + AstNode* parse_from_clause(); + + // Parse PostgreSQL RETURNING clause + AstNode* parse_returning(); +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_UPDATE_PARSER_H +``` + +- [ ] **Step 4: Implement UpdateParser parse methods** + +Implement all methods in the header. Key logic for `parse()`: + +**MySQL flow:** +1. Parse options (LOW_PRIORITY, IGNORE) +2. Parse table references using `table_ref_parser_` (supports JOINs for multi-table) +3. Expect and consume `SET` keyword +4. Parse SET assignments +5. Parse optional WHERE +6. Parse optional ORDER BY (single-table only) +7. Parse optional LIMIT (single-table only) + +**PostgreSQL flow:** +1. Parse optional ONLY keyword +2. Parse single table reference with optional alias +3. Expect and consume `SET` keyword +4. Parse SET assignments +5. Parse optional FROM clause (additional table sources) +6. Parse optional WHERE +7. Parse optional RETURNING + +Use `if constexpr (D == Dialect::MySQL)` and `if constexpr (D == Dialect::PostgreSQL)` for dialect-specific branches. + +Refer to `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` for full AST structure details. 
+ +--- + +### Task 3: Emitter Support for UPDATE Nodes + +**Files:** +- Modify: `include/sql_parser/emitter.h` + +- [ ] **Step 1: Add emit methods for UPDATE node types** + +Add cases to the `emit_node()` switch: +```cpp +case NodeType::NODE_UPDATE_STMT: emit_update_stmt(node); break; +case NodeType::NODE_UPDATE_SET_CLAUSE: emit_update_set_clause(node); break; +``` + +Note: `NODE_UPDATE_SET_ITEM`, `NODE_STMT_OPTIONS`, `NODE_RETURNING_CLAUSE` emitters were added in Plan 7. + +- [ ] **Step 2: Implement emit_update_stmt and emit_update_set_clause** + +```cpp +void emit_update_stmt(const AstNode* node); +void emit_update_set_clause(const AstNode* node); +``` + +`emit_update_stmt()` must handle the different child ordering between MySQL (table refs before SET) and PostgreSQL (single table, SET, then optional FROM). The emitter walks the children and emits them in the correct order based on their node types. + +--- + +### Task 4: Classifier Integration + +**Files:** +- Modify: `include/sql_parser/parser.h` +- Modify: `src/sql_parser/parser.cpp` + +- [ ] **Step 1: Add `parse_update()` method declaration** + +Add to the private section of `Parser`: +```cpp +ParseResult parse_update(); +``` + +- [ ] **Step 2: Implement `parse_update()` in `parser.cpp`** + +Follow the same pattern as `parse_insert()`: +```cpp +template <Dialect D> +ParseResult Parser<D>::parse_update() { + ParseResult r; + r.stmt_type = StmtType::UPDATE; + + UpdateParser<D> update_parser(tokenizer_, arena_); + AstNode* ast = update_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} +``` + +- [ ] **Step 3: Update `classify_and_dispatch()` switch** + +Replace: +```cpp +case TokenType::TK_UPDATE: return extract_update(first); +``` +With: +```cpp +case TokenType::TK_UPDATE: return parse_update(); +``` + +- [ ] **Step 4: Add `#include "sql_parser/update_parser.h"` to `parser.cpp`** + +- [ ] **Step 5: Run all tests** + 
+```bash +make -f Makefile.new test +``` + +All existing tests plus new UPDATE tests should pass. diff --git a/docs/superpowers/plans/2026-03-24-plan9-delete-parser.md b/docs/superpowers/plans/2026-03-24-plan9-delete-parser.md new file mode 100644 index 0000000..ee51dce --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-plan9-delete-parser.md @@ -0,0 +1,592 @@ +# DELETE Deep Parser Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build a full DELETE deep parser for MySQL and PostgreSQL that handles single-table delete, MySQL multi-table (both forms), PostgreSQL USING, ORDER BY/LIMIT (MySQL), and RETURNING (PostgreSQL). + +**Architecture:** `DeleteParser` is a header-only template class following the established pattern. It uses `TableRefParser` (from Plan 7) for table reference parsing and `ExpressionParser` for expressions. MySQL multi-table DELETE has two forms that require disambiguation: form 1 (`DELETE t1, t2 FROM ...`) and form 2 (`DELETE FROM t1, t2 USING ...`). The parser disambiguates by checking whether `FROM` appears immediately after options/targets or whether a table list precedes it. + +**Tech Stack:** C++17, existing parser infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` + +--- + +## Scope + +This plan builds: +1. `DeleteParser` — full DELETE deep parser for both dialects +2. Emitter extensions for DELETE-specific node types +3. Classifier update to route TK_DELETE to the deep parser +4. Comprehensive tests including round-trip tests + +**Closes:** #7 + +**Dependencies:** Plan 7 (provides `TableRefParser`, `NODE_STMT_OPTIONS`, `NODE_RETURNING_CLAUSE`, and related infrastructure). Plan 8 is independent of this plan. 
+ +--- + +## File Structure + +``` +include/sql_parser/ + delete_parser.h — (create) DELETE statement parser (header-only template) + emitter.h — (modify) add DELETE emit methods + common.h — (modify) add NODE_DELETE_STMT, NODE_DELETE_USING_CLAUSE + +src/sql_parser/ + parser.cpp — (modify) replace extract_delete with parse_delete() + +include/sql_parser/ + parser.h — (modify) add parse_delete() declaration + +tests/ + test_delete.cpp — (create) DELETE parser tests + +Makefile.new — (modify) add test_delete.cpp to TEST_SRCS +``` + +--- + +### Task 1: Add DELETE Node Types + +**Files:** +- Modify: `include/sql_parser/common.h` + +- [ ] **Step 1: Add new node types** + +Add to `NodeType` enum (if not already present from Plan 7/8): +```cpp +// DELETE nodes +NODE_DELETE_STMT, +NODE_DELETE_USING_CLAUSE, // PostgreSQL USING or MySQL USING form +``` + +Note: `NODE_STMT_OPTIONS`, `NODE_RETURNING_CLAUSE`, `NODE_FROM_CLAUSE`, `NODE_WHERE_CLAUSE`, `NODE_ORDER_BY_CLAUSE`, `NODE_LIMIT_CLAUSE` already exist. 
+ +--- + +### Task 2: DELETE Parser Implementation + +**Files:** +- Create: `include/sql_parser/delete_parser.h` +- Create: `tests/test_delete.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Write tests for DELETE parsing** + +Create `tests/test_delete.cpp`: +```cpp +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLDeleteTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Basic DELETE ========== + +TEST_F(MySQLDeleteTest, SimpleDelete) { + auto r = parser.parse("DELETE FROM users WHERE id = 1", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_DELETE_STMT); +} + +TEST_F(MySQLDeleteTest, DeleteNoWhere) { + auto r = parser.parse("DELETE FROM users", 17); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* where = find_child(r.ast, NodeType::NODE_WHERE_CLAUSE); + EXPECT_EQ(where, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteQualifiedTable) { + auto r = parser.parse("DELETE FROM mydb.users WHERE id = 1", 35); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteComplexWhere) { + const char* sql = "DELETE FROM users WHERE status = 'inactive' AND last_login < '2020-01-01'"; + 
auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL Options ========== + +TEST_F(MySQLDeleteTest, DeleteLowPriority) { + auto r = parser.parse("DELETE LOW_PRIORITY FROM users WHERE id = 1", 44); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_STMT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteQuick) { + auto r = parser.parse("DELETE QUICK FROM users WHERE id = 1", 36); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteIgnore) { + auto r = parser.parse("DELETE IGNORE FROM users WHERE id = 1", 37); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteAllOptions) { + auto r = parser.parse("DELETE LOW_PRIORITY QUICK IGNORE FROM users WHERE id = 1", 56); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL ORDER BY + LIMIT ========== + +TEST_F(MySQLDeleteTest, DeleteOrderByLimit) { + const char* sql = "DELETE FROM users WHERE active = 0 ORDER BY created_at ASC LIMIT 100"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteLimitOnly) { + auto r = parser.parse("DELETE FROM users WHERE active = 0 LIMIT 100", 45); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +// ========== MySQL Multi-Table Form 1: DELETE t1, t2 FROM ... 
========== + +TEST_F(MySQLDeleteTest, MultiTableForm1Single) { + const char* sql = "DELETE t1 FROM t1 JOIN t2 ON t1.id = t2.fk WHERE t2.status = 0"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, MultiTableForm1Multiple) { + const char* sql = "DELETE t1, t2 FROM t1 JOIN t2 ON t1.id = t2.fk WHERE t1.status = 0"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL Multi-Table Form 2: DELETE FROM t1, t2 USING ... ========== + +TEST_F(MySQLDeleteTest, MultiTableForm2) { + const char* sql = "DELETE FROM t1, t2 USING t1 JOIN t2 ON t1.id = t2.fk WHERE t1.status = 0"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, MultiTableForm2Single) { + const char* sql = "DELETE FROM t1 USING t1 JOIN t2 ON t1.id = t2.fk WHERE t2.bad = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== PostgreSQL DELETE ========== + +class PgSQLDeleteTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLDeleteTest, SimpleDelete) { + auto r = parser.parse("DELETE FROM users WHERE id = 1", 30); + EXPECT_EQ(r.status, 
ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteUsing) { + const char* sql = "DELETE FROM users USING orders WHERE users.id = orders.user_id AND orders.bad = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* using_clause = find_child(r.ast, NodeType::NODE_DELETE_USING_CLAUSE); + ASSERT_NE(using_clause, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteUsingMultiple) { + const char* sql = "DELETE FROM t1 USING t2, t3 WHERE t1.id = t2.fk AND t2.id = t3.fk"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteReturning) { + const char* sql = "DELETE FROM users WHERE id = 1 RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteReturningColumns) { + const char* sql = "DELETE FROM users WHERE id = 1 RETURNING id, name"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); + EXPECT_EQ(child_count(ret), 2); +} + +TEST_F(PgSQLDeleteTest, DeleteUsingReturning) { + const char* sql = "DELETE FROM users USING orders " + "WHERE users.id = orders.user_id RETURNING users.id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_DELETE_USING_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE), nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteWithAlias) { + const char* sql = "DELETE FROM users AS u WHERE u.id = 1"; + auto r = parser.parse(sql, strlen(sql)); + 
EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct DeleteTestCase { + const char* sql; + const char* description; +}; + +static const DeleteTestCase mysql_delete_bulk_cases[] = { + {"DELETE FROM t", "simple no where"}, + {"DELETE FROM t WHERE a = 1", "simple with where"}, + {"DELETE FROM t WHERE a > 1 AND b < 10", "complex where"}, + {"DELETE FROM db.t WHERE a = 1", "qualified table"}, + {"DELETE LOW_PRIORITY FROM t WHERE a = 1", "low priority"}, + {"DELETE QUICK FROM t WHERE a = 1", "quick"}, + {"DELETE IGNORE FROM t WHERE a = 1", "ignore"}, + {"DELETE LOW_PRIORITY QUICK IGNORE FROM t WHERE a = 1", "all options"}, + {"DELETE FROM t WHERE a = 1 ORDER BY b LIMIT 10", "order by limit"}, + {"DELETE FROM t WHERE a = 1 LIMIT 100", "limit only"}, + {"DELETE t1 FROM t1 JOIN t2 ON t1.id = t2.fk WHERE t2.x = 0", "multi-table form 1"}, + {"DELETE t1, t2 FROM t1 JOIN t2 ON t1.id = t2.fk", "multi-table form 1 multi target"}, + {"DELETE FROM t1 USING t1 JOIN t2 ON t1.id = t2.fk WHERE t2.x = 0", "multi-table form 2"}, + {"DELETE FROM t1, t2 USING t1 JOIN t2 ON t1.id = t2.fk", "multi-table form 2 multi target"}, +}; + +TEST(MySQLDeleteBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_delete_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +static const DeleteTestCase pgsql_delete_bulk_cases[] = { + {"DELETE FROM t", "simple no where"}, + {"DELETE FROM t WHERE a = 1", "simple with where"}, + {"DELETE FROM t WHERE a > 1 AND b < 10", "complex where"}, + {"DELETE FROM t AS x WHERE x.a = 1", "alias"}, + {"DELETE FROM t USING t2 WHERE t.id = t2.fk", "using single"}, + {"DELETE FROM t USING 
t2, t3 WHERE t.id = t2.fk AND t2.id = t3.fk", "using multi"}, + {"DELETE FROM t WHERE a = 1 RETURNING *", "returning star"}, + {"DELETE FROM t WHERE a = 1 RETURNING a, b", "returning cols"}, + {"DELETE FROM t USING t2 WHERE t.id = t2.fk RETURNING t.a", "using + returning"}, +}; + +TEST(PgSQLDeleteBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_delete_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const DeleteTestCase mysql_delete_roundtrip_cases[] = { + {"DELETE FROM t WHERE a = 1", "simple"}, + {"DELETE LOW_PRIORITY QUICK IGNORE FROM t WHERE a = 1", "all options"}, + {"DELETE FROM t WHERE a = 1 ORDER BY b LIMIT 10", "order by limit"}, +}; + +TEST(MySQLDeleteRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : mysql_delete_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +static const DeleteTestCase pgsql_delete_roundtrip_cases[] = { + {"DELETE FROM t WHERE a = 1", "simple"}, + {"DELETE FROM t USING t2 WHERE t.id = t2.fk", "using"}, + {"DELETE FROM t WHERE a = 1 RETURNING *", "returning"}, +}; + +TEST(PgSQLDeleteRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : pgsql_delete_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << 
"\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} +``` + +- [ ] **Step 2: Add test_delete.cpp to Makefile.new** + +Add `$(TEST_DIR)/test_delete.cpp \` to the `TEST_SRCS` list. + +- [ ] **Step 3: Implement DeleteParser class** + +Create `include/sql_parser/delete_parser.h`: +```cpp +#ifndef SQL_PARSER_DELETE_PARSER_H +#define SQL_PARSER_DELETE_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" + +namespace sql_parser { + +// Flags on NODE_DELETE_STMT +static constexpr uint16_t FLAG_DELETE_MULTI_TABLE = 0x01; // multi-table form +static constexpr uint16_t FLAG_DELETE_FORM2 = 0x02; // MySQL form 2 (DELETE FROM ... USING) + +template +class DeleteParser { +public: + DeleteParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena), + table_ref_parser_(tokenizer, arena, expr_parser_) {} + + // Parse DELETE statement (DELETE keyword already consumed). + AstNode* parse(); + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser table_ref_parser_; + + // Parse MySQL options: LOW_PRIORITY, QUICK, IGNORE + AstNode* parse_stmt_options(); + + // Parse target table list for multi-table: t1 [, t2, ...] 
+ AstNode* parse_target_tables(); + + // Parse WHERE clause + AstNode* parse_where_clause(); + + // Parse ORDER BY (MySQL single-table) + AstNode* parse_order_by(); + + // Parse LIMIT (MySQL single-table) + AstNode* parse_limit(); + + // Parse PostgreSQL USING clause + AstNode* parse_using_clause(); + + // Parse PostgreSQL RETURNING clause + AstNode* parse_returning(); +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_DELETE_PARSER_H +``` + +- [ ] **Step 4: Implement DeleteParser parse methods** + +The key complexity is MySQL multi-table disambiguation in `parse()`: + +**MySQL flow:** +1. Parse options (LOW_PRIORITY, QUICK, IGNORE) +2. Check if next token is `FROM`: + - **Yes**: could be single-table OR form 2. Consume FROM, parse table list. + - If next token is `USING`: this is **form 2**. Consume USING, parse table references (source tables with JOINs). + - Otherwise: single-table delete. Parse optional WHERE, ORDER BY, LIMIT. + - **No**: this is **form 1**. Parse target table list (t1, t2, ...), then expect FROM, parse table references (source tables with JOINs), then WHERE. + +**PostgreSQL flow:** +1. Consume FROM +2. Parse optional ONLY keyword +3. Parse single table reference with optional alias +4. Parse optional USING clause +5. Parse optional WHERE +6. Parse optional RETURNING + +Refer to `docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md` for full syntax and AST structure. 
+ +--- + +### Task 3: Emitter Support for DELETE Nodes + +**Files:** +- Modify: `include/sql_parser/emitter.h` + +- [ ] **Step 1: Add emit methods for DELETE node types** + +Add cases to the `emit_node()` switch: +```cpp +case NodeType::NODE_DELETE_STMT: emit_delete_stmt(node); break; +case NodeType::NODE_DELETE_USING_CLAUSE: emit_delete_using(node); break; +``` + +- [ ] **Step 2: Implement emit_delete_stmt and emit_delete_using** + +```cpp +void emit_delete_stmt(const AstNode* node); +void emit_delete_using(const AstNode* node); +``` + +`emit_delete_stmt()` must check `flags` to determine the delete form: +- Single-table: `DELETE [options] FROM table [WHERE] [ORDER BY] [LIMIT]` +- Multi-table form 1: `DELETE [options] targets FROM table_refs [WHERE]` +- Multi-table form 2: `DELETE [options] FROM targets USING table_refs [WHERE]` +- PostgreSQL: `DELETE FROM table [USING] [WHERE] [RETURNING]` + +--- + +### Task 4: Classifier Integration + +**Files:** +- Modify: `include/sql_parser/parser.h` +- Modify: `src/sql_parser/parser.cpp` + +- [ ] **Step 1: Add `parse_delete()` declaration** + +Add to the private section of `Parser`: +```cpp +ParseResult parse_delete(); +``` + +- [ ] **Step 2: Implement `parse_delete()` in `parser.cpp`** + +```cpp +template +ParseResult Parser::parse_delete() { + ParseResult r; + r.stmt_type = StmtType::DELETE_STMT; + + DeleteParser delete_parser(tokenizer_, arena_); + AstNode* ast = delete_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} +``` + +- [ ] **Step 3: Update `classify_and_dispatch()` switch** + +Replace: +```cpp +case TokenType::TK_DELETE: return extract_delete(first); +``` +With: +```cpp +case TokenType::TK_DELETE: return parse_delete(); +``` + +- [ ] **Step 4: Add `#include "sql_parser/delete_parser.h"` to `parser.cpp`** + +- [ ] **Step 5: Run all tests** + +```bash +make -f Makefile.new test +``` + +All existing tests 
plus new DELETE tests should pass. diff --git a/docs/superpowers/plans/2026-03-24-prepared-stmt-cache.md b/docs/superpowers/plans/2026-03-24-prepared-stmt-cache.md new file mode 100644 index 0000000..10a950a --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-prepared-stmt-cache.md @@ -0,0 +1,693 @@ +# Prepared Statement Cache Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add prepared statement cache and binary protocol support — `parse_and_cache()`, `execute()`, and `prepare_cache_evict()` — so ProxySQL can handle COM_STMT_PREPARE/EXECUTE efficiently. + +**Architecture:** The statement cache is a fixed-capacity LRU map (keyed by statement ID) that stores deep-copied ASTs outside the arena. `parse_and_cache()` parses normally, then copies the AST to the cache. `execute()` looks up the cached AST and returns it with bound parameter bindings. The emitter is extended to materialize placeholders from bindings. + +**Tech Stack:** C++17, existing parser/arena/emitter infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-sql-parser-design.md` (Binary Protocol section) + +--- + +## Scope + +1. `BoundValue` and `ParamBindings` structs (in `parse_result.h`) +2. AST deep-copy utility (copy AST tree from arena to heap for caching) +3. `StmtCache` — fixed-capacity LRU cache keyed by statement ID +4. `parse_and_cache()`, `execute()`, `prepare_cache_evict()` on `Parser` +5. Emitter extension — materialize `NODE_PLACEHOLDER` from bindings +6. Tests for the full prepare/execute/evict lifecycle + +**Not in scope:** Actual MySQL wire protocol decoding (ProxySQL handles that). 
+
+---
+
+## File Structure
+
+```
+include/sql_parser/
+  parse_result.h   — (modify) Add BoundValue, ParamBindings
+  stmt_cache.h     — StmtCache class + AST deep-copy
+  emitter.h        — (modify) Add bindings-aware placeholder emission
+  parser.h         — (modify) Add parse_and_cache, execute, prepare_cache_evict
+
+src/sql_parser/
+  parser.cpp       — (modify) Implement new methods
+
+tests/
+  test_stmt_cache.cpp — Prepared statement lifecycle tests
+
+Makefile.new       — (modify) Add test_stmt_cache.cpp
+```
+
+---
+
+### Task 1: BoundValue, ParamBindings, StmtCache, AST Deep-Copy
+
+**Files:**
+- Modify: `include/sql_parser/parse_result.h` — add BoundValue, ParamBindings
+- Create: `include/sql_parser/stmt_cache.h` — StmtCache + AST deep-copy
+- Create: `tests/test_stmt_cache.cpp`
+- Modify: `Makefile.new`
+
+- [ ] **Step 1: Add BoundValue and ParamBindings to parse_result.h**
+
+Add to `include/sql_parser/parse_result.h` before `ParseResult`:
+```cpp
+struct BoundValue {
+  enum Type : uint8_t { INT, FLOAT, DOUBLE, STRING, BLOB, NULL_VAL, DATETIME, DECIMAL };
+  Type type = NULL_VAL;
+  union {
+    int64_t int_val;
+    float float32_val;
+    double float64_val;
+    StringRef str_val;
+  };
+};
+static_assert(std::is_trivially_copyable_v<BoundValue>);
+
+struct ParamBindings {
+  BoundValue* values = nullptr;
+  uint16_t count = 0;
+};
+```
+
+Also add a `ParamBindings` field to `ParseResult`:
+```cpp
+  ParamBindings bindings;  // populated by execute()
+```
+
+- [ ] **Step 2: Create stmt_cache.h**
+
+Create `include/sql_parser/stmt_cache.h`:
+```cpp
+#ifndef SQL_PARSER_STMT_CACHE_H
+#define SQL_PARSER_STMT_CACHE_H
+
+#include "sql_parser/ast.h"
+#include "sql_parser/common.h"
+#include "sql_parser/parse_result.h"
+#include <cstdlib>
+#include <cstring>
+#include <list>
+#include <unordered_map>
+
+namespace sql_parser {
+
+// Deep-copy an AST tree from arena to heap memory.
+// The returned tree must be freed with free_ast().
+inline AstNode* deep_copy_ast(const AstNode* src) {
+  if (!src) return nullptr;
+
+  AstNode* dst = static_cast<AstNode*>(std::malloc(sizeof(AstNode)));
+  if (!dst) return nullptr;
+
+  dst->type = src->type;
+  dst->flags = src->flags;
+  dst->first_child = nullptr;
+  dst->next_sibling = nullptr;
+
+  // Deep-copy value string to heap
+  if (src->value_ptr && src->value_len > 0) {
+    char* val_copy = static_cast<char*>(std::malloc(src->value_len));
+    if (val_copy) {
+      std::memcpy(val_copy, src->value_ptr, src->value_len);
+    }
+    dst->value_ptr = val_copy;
+    dst->value_len = src->value_len;
+  } else {
+    dst->value_ptr = nullptr;
+    dst->value_len = 0;
+  }
+
+  // Recursively copy children
+  const AstNode* src_child = src->first_child;
+  AstNode* prev_dst_child = nullptr;
+  while (src_child) {
+    AstNode* dst_child = deep_copy_ast(src_child);
+    if (dst_child) {
+      if (!dst->first_child) {
+        dst->first_child = dst_child;
+      } else if (prev_dst_child) {
+        prev_dst_child->next_sibling = dst_child;
+      }
+      prev_dst_child = dst_child;
+    }
+    src_child = src_child->next_sibling;
+  }
+
+  return dst;
+}
+
+// Free a heap-allocated AST tree (produced by deep_copy_ast).
+inline void free_ast(AstNode* node) {
+  if (!node) return;
+  // Free children first
+  AstNode* child = node->first_child;
+  while (child) {
+    AstNode* next = child->next_sibling;
+    free_ast(child);
+    child = next;
+  }
+  // Free value string
+  if (node->value_ptr) {
+    std::free(const_cast<char*>(node->value_ptr));
+  }
+  std::free(node);
+}
+
+// Cached entry for a prepared statement.
+struct CachedStmt { + uint32_t stmt_id; + StmtType stmt_type; + AstNode* ast; // heap-allocated deep copy + + ~CachedStmt() { + free_ast(ast); + } + + // Non-copyable + CachedStmt(const CachedStmt&) = delete; + CachedStmt& operator=(const CachedStmt&) = delete; + CachedStmt(CachedStmt&& o) noexcept + : stmt_id(o.stmt_id), stmt_type(o.stmt_type), ast(o.ast) { + o.ast = nullptr; + } + CachedStmt& operator=(CachedStmt&& o) noexcept { + if (this != &o) { + free_ast(ast); + stmt_id = o.stmt_id; + stmt_type = o.stmt_type; + ast = o.ast; + o.ast = nullptr; + } + return *this; + } + + CachedStmt() : stmt_id(0), stmt_type(StmtType::UNKNOWN), ast(nullptr) {} + CachedStmt(uint32_t id, StmtType type, AstNode* a) + : stmt_id(id), stmt_type(type), ast(a) {} +}; + +// Fixed-capacity LRU cache for prepared statements. +class StmtCache { +public: + explicit StmtCache(size_t capacity = 128) : capacity_(capacity) {} + + ~StmtCache() { clear(); } + + // Non-copyable + StmtCache(const StmtCache&) = delete; + StmtCache& operator=(const StmtCache&) = delete; + + // Store a prepared statement. Deep-copies the AST from the arena. + // Evicts LRU entry if at capacity. + bool store(uint32_t stmt_id, StmtType stmt_type, const AstNode* ast) { + // If already exists, remove old entry + evict(stmt_id); + + AstNode* copy = deep_copy_ast(ast); + if (!copy && ast) return false; + + // Evict LRU if at capacity + if (lru_.size() >= capacity_) { + auto& oldest = lru_.back(); + map_.erase(oldest.stmt_id); + lru_.pop_back(); + } + + lru_.emplace_front(stmt_id, stmt_type, copy); + map_[stmt_id] = lru_.begin(); + return true; + } + + // Look up a cached statement. Returns nullptr if not found. + // Moves the entry to front of LRU. + const CachedStmt* lookup(uint32_t stmt_id) { + auto it = map_.find(stmt_id); + if (it == map_.end()) return nullptr; + // Move to front (most recently used) + lru_.splice(lru_.begin(), lru_, it->second); + return &(*it->second); + } + + // Evict a specific statement. 
+ void evict(uint32_t stmt_id) { + auto it = map_.find(stmt_id); + if (it != map_.end()) { + lru_.erase(it->second); + map_.erase(it); + } + } + + // Clear all entries. + void clear() { + lru_.clear(); + map_.clear(); + } + + size_t size() const { return map_.size(); } + size_t capacity() const { return capacity_; } + +private: + size_t capacity_; + std::list lru_; + std::unordered_map::iterator> map_; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_STMT_CACHE_H +``` + +- [ ] **Step 3: Write tests** + +Create `tests/test_stmt_cache.cpp`: +```cpp +#include +#include "sql_parser/parser.h" +#include "sql_parser/stmt_cache.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +// ========== StmtCache unit tests ========== + +TEST(StmtCacheTest, StoreAndLookup) { + StmtCache cache(16); + Arena arena(4096); + + AstNode* node = make_node(arena, NodeType::NODE_SET_STMT, StringRef{"SET", 3}); + ASSERT_NE(node, nullptr); + + EXPECT_TRUE(cache.store(1, StmtType::SET, node)); + EXPECT_EQ(cache.size(), 1u); + + const CachedStmt* found = cache.lookup(1); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->stmt_id, 1u); + EXPECT_EQ(found->stmt_type, StmtType::SET); + ASSERT_NE(found->ast, nullptr); + EXPECT_EQ(found->ast->type, NodeType::NODE_SET_STMT); +} + +TEST(StmtCacheTest, LookupMiss) { + StmtCache cache(16); + EXPECT_EQ(cache.lookup(999), nullptr); +} + +TEST(StmtCacheTest, Evict) { + StmtCache cache(16); + Arena arena(4096); + + AstNode* node = make_node(arena, NodeType::NODE_SELECT_STMT); + cache.store(1, StmtType::SELECT, node); + EXPECT_EQ(cache.size(), 1u); + + cache.evict(1); + EXPECT_EQ(cache.size(), 0u); + EXPECT_EQ(cache.lookup(1), nullptr); +} + +TEST(StmtCacheTest, LRUEviction) { + StmtCache cache(2); // capacity = 2 + Arena arena(4096); + + AstNode* n1 = make_node(arena, NodeType::NODE_SET_STMT); + AstNode* n2 = make_node(arena, NodeType::NODE_SELECT_STMT); + AstNode* n3 = make_node(arena, NodeType::NODE_SET_STMT); + + cache.store(1, 
StmtType::SET, n1); + cache.store(2, StmtType::SELECT, n2); + EXPECT_EQ(cache.size(), 2u); + + // Adding a third should evict the LRU (stmt 1) + cache.store(3, StmtType::SET, n3); + EXPECT_EQ(cache.size(), 2u); + EXPECT_EQ(cache.lookup(1), nullptr); // evicted + EXPECT_NE(cache.lookup(2), nullptr); // still there + EXPECT_NE(cache.lookup(3), nullptr); // just added +} + +TEST(StmtCacheTest, LRUTouchOnLookup) { + StmtCache cache(2); + Arena arena(4096); + + AstNode* n1 = make_node(arena, NodeType::NODE_SET_STMT); + AstNode* n2 = make_node(arena, NodeType::NODE_SELECT_STMT); + AstNode* n3 = make_node(arena, NodeType::NODE_SET_STMT); + + cache.store(1, StmtType::SET, n1); + cache.store(2, StmtType::SELECT, n2); + + // Touch stmt 1 to make it recently used + cache.lookup(1); + + // Adding stmt 3 should evict stmt 2 (now the LRU) + cache.store(3, StmtType::SET, n3); + EXPECT_NE(cache.lookup(1), nullptr); // touched, still alive + EXPECT_EQ(cache.lookup(2), nullptr); // evicted + EXPECT_NE(cache.lookup(3), nullptr); +} + +TEST(StmtCacheTest, DeepCopyPreservesTree) { + Arena arena(4096); + + // Build a small tree: SET_STMT -> VAR_ASSIGNMENT -> (VAR_TARGET, LITERAL_INT) + AstNode* root = make_node(arena, NodeType::NODE_SET_STMT); + AstNode* assign = make_node(arena, NodeType::NODE_VAR_ASSIGNMENT); + AstNode* target = make_node(arena, NodeType::NODE_VAR_TARGET); + target->add_child(make_node(arena, NodeType::NODE_IDENTIFIER, StringRef{"autocommit", 10})); + AstNode* value = make_node(arena, NodeType::NODE_LITERAL_INT, StringRef{"1", 1}); + assign->add_child(target); + assign->add_child(value); + root->add_child(assign); + + // Deep copy + AstNode* copy = deep_copy_ast(root); + ASSERT_NE(copy, nullptr); + EXPECT_EQ(copy->type, NodeType::NODE_SET_STMT); + + // Verify tree structure is preserved + ASSERT_NE(copy->first_child, nullptr); + EXPECT_EQ(copy->first_child->type, NodeType::NODE_VAR_ASSIGNMENT); + + AstNode* copy_target = copy->first_child->first_child; + 
ASSERT_NE(copy_target, nullptr); + EXPECT_EQ(copy_target->type, NodeType::NODE_VAR_TARGET); + + AstNode* copy_name = copy_target->first_child; + ASSERT_NE(copy_name, nullptr); + EXPECT_EQ(std::string(copy_name->value_ptr, copy_name->value_len), "autocommit"); + + // Verify it's a deep copy (different pointers) + EXPECT_NE(copy, root); + EXPECT_NE(copy->first_child, root->first_child); + EXPECT_NE(copy_name->value_ptr, target->first_child->value_ptr); + + // Reset arena — copy should still be valid + arena.reset(); + EXPECT_EQ(std::string(copy_name->value_ptr, copy_name->value_len), "autocommit"); + + free_ast(copy); +} + +// ========== Parser integration tests ========== + +TEST(PreparedStmtTest, ParseAndCache) { + Parser parser; + + auto r = parser.parse_and_cache("SELECT * FROM users WHERE id = ?", 32, 1); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + ASSERT_NE(r.ast, nullptr); +} + +TEST(PreparedStmtTest, ExecuteAfterCache) { + Parser parser; + + parser.parse_and_cache("SET autocommit = ?", 18, 42); + + // Build bindings + BoundValue bv; + bv.type = BoundValue::INT; + bv.int_val = 0; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(42, bindings); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.bindings.count, 1); + EXPECT_EQ(r.bindings.values[0].int_val, 0); +} + +TEST(PreparedStmtTest, ExecuteNotFound) { + Parser parser; + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(999, bindings); + EXPECT_EQ(r.status, ParseResult::ERROR); +} + +TEST(PreparedStmtTest, EvictAndExecuteFails) { + Parser parser; + + parser.parse_and_cache("SELECT 1", 8, 10); + parser.prepare_cache_evict(10); + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(10, bindings); + EXPECT_EQ(r.status, ParseResult::ERROR); +} + 
+TEST(PreparedStmtTest, CacheMultipleStatements) { + Parser parser; + + parser.parse_and_cache("SELECT 1", 8, 1); + parser.parse_and_cache("SELECT 2", 8, 2); + parser.parse_and_cache("SET autocommit = 0", 18, 3); + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 0}; + + auto r1 = parser.execute(1, bindings); + EXPECT_EQ(r1.status, ParseResult::OK); + EXPECT_EQ(r1.stmt_type, StmtType::SELECT); + + auto r3 = parser.execute(3, bindings); + EXPECT_EQ(r3.status, ParseResult::OK); + EXPECT_EQ(r3.stmt_type, StmtType::SET); +} + +// ========== Emitter with bindings ========== + +TEST(PreparedStmtTest, EmitWithBindings) { + Parser parser; + + parser.parse_and_cache("SET autocommit = ?", 18, 1); + + BoundValue bv; + bv.type = BoundValue::INT; + bv.int_val = 1; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(1, bindings); + ASSERT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + + Emitter emitter(parser.arena(), &r.bindings); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET autocommit = 1"); +} + +TEST(PreparedStmtTest, EmitWithStringBinding) { + Parser parser; + + parser.parse_and_cache("SET sql_mode = ?", 16, 2); + + const char* mode = "TRADITIONAL"; + BoundValue bv; + bv.type = BoundValue::STRING; + bv.str_val = StringRef{mode, 11}; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(2, bindings); + ASSERT_EQ(r.status, ParseResult::OK); + + Emitter emitter(parser.arena(), &r.bindings); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET sql_mode = 'TRADITIONAL'"); +} + +TEST(PreparedStmtTest, EmitWithNullBinding) { + Parser parser; + + parser.parse_and_cache("SET character_set_results = ?", 28, 3); + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(3, bindings); + ASSERT_EQ(r.status, 
ParseResult::OK); + + Emitter emitter(parser.arena(), &r.bindings); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET character_set_results = NULL"); +} +``` + +- [ ] **Step 4: Update Parser class — add new methods** + +Modify `include/sql_parser/parser.h` — add declarations after existing public methods: +```cpp + // Prepared statement support + ParseResult parse_and_cache(const char* sql, size_t len, uint32_t stmt_id); + ParseResult execute(uint32_t stmt_id, const ParamBindings& params); + void prepare_cache_evict(uint32_t stmt_id); +``` + +Add private member: +```cpp + StmtCache stmt_cache_; +``` + +Add include at top: +```cpp +#include "sql_parser/stmt_cache.h" +``` + +Update `ParserConfig` with cache capacity: +```cpp +struct ParserConfig { + size_t arena_block_size = 65536; + size_t arena_max_size = 1048576; + size_t stmt_cache_capacity = 128; +}; +``` + +Update constructor to use config: +```cpp + explicit Parser(const ParserConfig& config = {}); +``` + +Modify `src/sql_parser/parser.cpp` — update constructor and add implementations: + +Update constructor: +```cpp +template +Parser::Parser(const ParserConfig& config) + : arena_(config.arena_block_size, config.arena_max_size), + stmt_cache_(config.stmt_cache_capacity) {} +``` + +Add new methods: +```cpp +template +ParseResult Parser::parse_and_cache(const char* sql, size_t len, uint32_t stmt_id) { + ParseResult r = parse(sql, len); + if (r.ast) { + stmt_cache_.store(stmt_id, r.stmt_type, r.ast); + } + return r; +} + +template +ParseResult Parser::execute(uint32_t stmt_id, const ParamBindings& params) { + ParseResult r; + const CachedStmt* cached = stmt_cache_.lookup(stmt_id); + if (!cached) { + r.status = ParseResult::ERROR; + r.stmt_type = StmtType::UNKNOWN; + return r; + } + r.status = ParseResult::OK; + r.stmt_type = cached->stmt_type; + r.ast = cached->ast; + r.bindings = params; + return r; +} + +template +void 
Parser::prepare_cache_evict(uint32_t stmt_id) { + stmt_cache_.evict(stmt_id); +} +``` + +- [ ] **Step 5: Extend Emitter to handle bindings** + +Modify `include/sql_parser/emitter.h` — add bindings-aware constructor and placeholder emission: + +Update constructor to optionally accept bindings: +```cpp + explicit Emitter(Arena& arena, const ParamBindings* bindings = nullptr) + : sb_(arena), bindings_(bindings), placeholder_index_(0) {} +``` + +Add private member: +```cpp + const ParamBindings* bindings_; + uint16_t placeholder_index_; +``` + +Update the `NODE_PLACEHOLDER` handling in `emit_node()` switch — change from `emit_value` to a new method: +```cpp + case NodeType::NODE_PLACEHOLDER: + emit_placeholder(node); break; +``` + +Add `emit_placeholder`: +```cpp + void emit_placeholder(const AstNode* node) { + if (bindings_ && placeholder_index_ < bindings_->count) { + const BoundValue& bv = bindings_->values[placeholder_index_]; + ++placeholder_index_; + switch (bv.type) { + case BoundValue::INT: + // Convert int to string in arena + { char buf[32]; int n = snprintf(buf, sizeof(buf), "%lld", (long long)bv.int_val); + sb_.append(buf, n); } + break; + case BoundValue::FLOAT: + { char buf[64]; int n = snprintf(buf, sizeof(buf), "%g", (double)bv.float32_val); + sb_.append(buf, n); } + break; + case BoundValue::DOUBLE: + { char buf[64]; int n = snprintf(buf, sizeof(buf), "%g", bv.float64_val); + sb_.append(buf, n); } + break; + case BoundValue::STRING: + case BoundValue::DATETIME: + case BoundValue::DECIMAL: + sb_.append_char('\''); + sb_.append(bv.str_val); + sb_.append_char('\''); + break; + case BoundValue::BLOB: + sb_.append(bv.str_val); + break; + case BoundValue::NULL_VAL: + sb_.append("NULL", 4); + break; + } + } else { + // No binding available — emit placeholder as-is + emit_value(node); + } + } +``` + +- [ ] **Step 6: Update Makefile.new and build** + +Add `$(TEST_DIR)/test_stmt_cache.cpp` to TEST_SRCS. 
+ +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` + +- [ ] **Step 7: Commit** + +```bash +git add include/sql_parser/stmt_cache.h include/sql_parser/parse_result.h \ + include/sql_parser/parser.h include/sql_parser/emitter.h \ + src/sql_parser/parser.cpp tests/test_stmt_cache.cpp Makefile.new +git commit -m "feat: add prepared statement cache with parse_and_cache, execute, and bindings-aware emitter" +``` diff --git a/docs/superpowers/plans/2026-03-24-query-emitter.md b/docs/superpowers/plans/2026-03-24-query-emitter.md new file mode 100644 index 0000000..61ce1a7 --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-query-emitter.md @@ -0,0 +1,947 @@ +# Query Emitter Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build the query emitter that reconstructs valid SQL from an AST, enabling ProxySQL to parse a query, modify the AST, and emit the modified query. + +**Architecture:** The emitter is a dialect-templated class `Emitter` that walks an `AstNode` tree and writes SQL into an arena-backed `StringBuilder`. For unmodified nodes, it emits the original input text via `StringRef` (zero-copy). For modified nodes, it emits the new values. The emitter handles all node types produced by the SET, SELECT, and expression parsers. Round-trip tests (parse → emit → parse → compare) validate correctness. + +**Tech Stack:** C++17, existing arena/AST infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-sql-parser-design.md` (Query Reconstruction section) + +--- + +## Scope + +This plan builds: +1. `StringBuilder` — arena-backed string builder for output +2. `Emitter` — walks AST, emits SQL for all node types +3. `emit()` convenience function on `Parser` — public API +4. Round-trip tests for SET and SELECT statements +5. 
Modification tests — parse, modify AST, emit, verify output + +**Not in scope:** Cross-dialect emission, prepared statement cache. + +--- + +## File Structure + +``` +include/sql_parser/ + string_builder.h — Arena-backed string builder + emitter.h — Emitter template (header-only) + parser.h — (modify) Add emit() method + +tests/ + test_emitter.cpp — Round-trip and modification tests + +Makefile.new — (modify) Add test_emitter.cpp +``` + +--- + +### Task 1: StringBuilder and Emitter — Expression Nodes + +**Files:** +- Create: `include/sql_parser/string_builder.h` +- Create: `include/sql_parser/emitter.h` +- Create: `tests/test_emitter.cpp` +- Modify: `Makefile.new` + +- [ ] **Step 1: Write tests for expression round-trips** + +Create `tests/test_emitter.cpp`: +```cpp +#include <gtest/gtest.h> +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLEmitterTest : public ::testing::Test { +protected: + Parser parser; + + // Parse, emit, return the emitted string + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== SET round-trips ========== + +TEST_F(MySQLEmitterTest, SetSimpleVariable) { + std::string out = round_trip("SET autocommit = 1"); + EXPECT_EQ(out, "SET autocommit = 1"); +} + +TEST_F(MySQLEmitterTest, SetMultipleVariables) { + std::string out = round_trip("SET autocommit = 1, wait_timeout = 28800"); + EXPECT_EQ(out, "SET autocommit = 1, wait_timeout = 28800"); +} + +TEST_F(MySQLEmitterTest, SetNames) { + std::string out = round_trip("SET NAMES utf8mb4"); + EXPECT_EQ(out, "SET NAMES utf8mb4"); +} + +TEST_F(MySQLEmitterTest, SetNamesCollate) { + std::string out = round_trip("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"); + EXPECT_EQ(out, "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"); 
+} + +TEST_F(MySQLEmitterTest, SetCharacterSet) { + std::string out = round_trip("SET CHARACTER SET utf8"); + EXPECT_EQ(out, "SET CHARACTER SET utf8"); +} + +TEST_F(MySQLEmitterTest, SetCharset) { + // CHARSET is normalized to CHARACTER SET in emitted output + std::string out = round_trip("SET CHARSET utf8"); + EXPECT_EQ(out, "SET CHARACTER SET utf8"); +} + +TEST_F(MySQLEmitterTest, SetGlobalVariable) { + std::string out = round_trip("SET GLOBAL max_connections = 100"); + EXPECT_EQ(out, "SET GLOBAL max_connections = 100"); +} + +TEST_F(MySQLEmitterTest, SetSessionVariable) { + std::string out = round_trip("SET SESSION wait_timeout = 600"); + EXPECT_EQ(out, "SET SESSION wait_timeout = 600"); +} + +TEST_F(MySQLEmitterTest, SetDoubleAtVariable) { + std::string out = round_trip("SET @@session.wait_timeout = 600"); + EXPECT_EQ(out, "SET @@session.wait_timeout = 600"); +} + +TEST_F(MySQLEmitterTest, SetUserVariable) { + std::string out = round_trip("SET @my_var = 42"); + EXPECT_EQ(out, "SET @my_var = 42"); +} + +TEST_F(MySQLEmitterTest, SetTransaction) { + std::string out = round_trip("SET TRANSACTION READ ONLY"); + EXPECT_EQ(out, "SET TRANSACTION READ ONLY"); +} + +TEST_F(MySQLEmitterTest, SetTransactionIsolation) { + // ISOLATION LEVEL keywords are consumed by parser; emitter outputs the level value directly + std::string out = round_trip("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); + EXPECT_EQ(out, "SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); + // Note: To support this, the SET parser must preserve "ISOLATION LEVEL" in the AST. + // The emitter's emit_set_transaction() must check children and re-insert the keywords. 
+} + +TEST_F(MySQLEmitterTest, SetFunctionRHS) { + std::string out = round_trip("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); + EXPECT_EQ(out, "SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); +} + +// ========== SELECT round-trips ========== + +TEST_F(MySQLEmitterTest, SelectLiteral) { + std::string out = round_trip("SELECT 1"); + EXPECT_EQ(out, "SELECT 1"); +} + +TEST_F(MySQLEmitterTest, SelectStar) { + std::string out = round_trip("SELECT * FROM users"); + EXPECT_EQ(out, "SELECT * FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectColumns) { + std::string out = round_trip("SELECT id, name FROM users"); + EXPECT_EQ(out, "SELECT id, name FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectWithAlias) { + std::string out = round_trip("SELECT id AS user_id FROM users"); + EXPECT_EQ(out, "SELECT id AS user_id FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectDistinct) { + std::string out = round_trip("SELECT DISTINCT name FROM users"); + EXPECT_EQ(out, "SELECT DISTINCT name FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectWhere) { + std::string out = round_trip("SELECT * FROM users WHERE id = 1"); + EXPECT_EQ(out, "SELECT * FROM users WHERE id = 1"); +} + +TEST_F(MySQLEmitterTest, SelectWhereAnd) { + std::string out = round_trip("SELECT * FROM users WHERE age > 18 AND status = 'active'"); + EXPECT_EQ(out, "SELECT * FROM users WHERE age > 18 AND status = 'active'"); +} + +TEST_F(MySQLEmitterTest, SelectJoin) { + std::string out = round_trip("SELECT * FROM users JOIN orders ON users.id = orders.user_id"); + EXPECT_EQ(out, "SELECT * FROM users JOIN orders ON users.id = orders.user_id"); +} + +TEST_F(MySQLEmitterTest, SelectLeftJoin) { + std::string out = round_trip("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id"); + EXPECT_EQ(out, "SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id"); +} + +TEST_F(MySQLEmitterTest, SelectGroupBy) { + std::string out = round_trip("SELECT status, COUNT(*) FROM users GROUP BY 
status"); + EXPECT_EQ(out, "SELECT status, COUNT(*) FROM users GROUP BY status"); +} + +TEST_F(MySQLEmitterTest, SelectGroupByHaving) { + std::string out = round_trip("SELECT status, COUNT(*) FROM users GROUP BY status HAVING COUNT(*) > 5"); + EXPECT_EQ(out, "SELECT status, COUNT(*) FROM users GROUP BY status HAVING COUNT(*) > 5"); +} + +TEST_F(MySQLEmitterTest, SelectOrderBy) { + std::string out = round_trip("SELECT * FROM users ORDER BY name ASC"); + EXPECT_EQ(out, "SELECT * FROM users ORDER BY name ASC"); +} + +TEST_F(MySQLEmitterTest, SelectLimit) { + std::string out = round_trip("SELECT * FROM users LIMIT 10"); + EXPECT_EQ(out, "SELECT * FROM users LIMIT 10"); +} + +TEST_F(MySQLEmitterTest, SelectLimitOffset) { + std::string out = round_trip("SELECT * FROM users LIMIT 10 OFFSET 20"); + EXPECT_EQ(out, "SELECT * FROM users LIMIT 10 OFFSET 20"); +} + +TEST_F(MySQLEmitterTest, SelectForUpdate) { + std::string out = round_trip("SELECT * FROM users FOR UPDATE"); + EXPECT_EQ(out, "SELECT * FROM users FOR UPDATE"); +} + +// ========== Expression round-trips ========== + +TEST_F(MySQLEmitterTest, ExprIsNull) { + std::string out = round_trip("SELECT * FROM t WHERE x IS NULL"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x IS NULL"); +} + +TEST_F(MySQLEmitterTest, ExprIsNotNull) { + std::string out = round_trip("SELECT * FROM t WHERE x IS NOT NULL"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x IS NOT NULL"); +} + +TEST_F(MySQLEmitterTest, ExprBetween) { + std::string out = round_trip("SELECT * FROM t WHERE x BETWEEN 1 AND 10"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x BETWEEN 1 AND 10"); +} + +TEST_F(MySQLEmitterTest, ExprIn) { + std::string out = round_trip("SELECT * FROM t WHERE x IN (1, 2, 3)"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x IN (1, 2, 3)"); +} + +TEST_F(MySQLEmitterTest, ExprFunctionCall) { + std::string out = round_trip("SELECT COUNT(*) FROM users"); + EXPECT_EQ(out, "SELECT COUNT(*) FROM users"); +} + +TEST_F(MySQLEmitterTest, ExprUnaryMinus) { + std::string 
out = round_trip("SELECT -1"); + EXPECT_EQ(out, "SELECT -1"); +} + +// ========== Bulk round-trip tests ========== + +struct RoundTripCase { + const char* sql; + const char* description; +}; + +static const RoundTripCase roundtrip_cases[] = { + {"SET autocommit = 0", "set simple"}, + {"SET NAMES utf8", "set names"}, + {"SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci", "set names collate"}, + {"SET CHARACTER SET utf8", "set character set"}, + {"SET GLOBAL max_connections = 100", "set global"}, + {"SET @x = 42", "set user var"}, + {"SET @@session.wait_timeout = 600", "set sys var"}, + {"SELECT 1", "select literal"}, + {"SELECT * FROM t", "select star"}, + {"SELECT a, b FROM t", "select columns"}, + {"SELECT a AS x FROM t", "select alias"}, + {"SELECT DISTINCT a FROM t", "select distinct"}, + {"SELECT * FROM t WHERE a = 1", "select where"}, + {"SELECT * FROM t WHERE a > 1 AND b < 10", "select where and"}, + {"SELECT * FROM t ORDER BY a", "select order by"}, + {"SELECT * FROM t ORDER BY a DESC", "select order by desc"}, + {"SELECT * FROM t LIMIT 10", "select limit"}, + {"SELECT * FROM t LIMIT 10 OFFSET 5", "select limit offset"}, + {"SELECT * FROM t FOR UPDATE", "select for update"}, + {"SELECT COUNT(*) FROM t", "select count"}, + {"SELECT * FROM t WHERE x IS NULL", "is null"}, + {"SELECT * FROM t WHERE x IS NOT NULL", "is not null"}, + {"SELECT * FROM t WHERE x IN (1, 2, 3)", "in list"}, + {"SELECT * FROM t WHERE x BETWEEN 1 AND 10", "between"}, +}; + +TEST(MySQLEmitterBulk, RoundTripsMatch) { + Parser parser; + for (const auto& tc : roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +// ========== AST modification tests ========== + 
+TEST_F(MySQLEmitterTest, ModifySetValue) { + // Parse SET autocommit = 1, modify value to 0, emit + auto r = parser.parse("SET autocommit = 1", 18); + ASSERT_NE(r.ast, nullptr); + + // Navigate to the value node: SET_STMT -> VAR_ASSIGNMENT -> (target, value) + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + ASSERT_EQ(assignment->type, NodeType::NODE_VAR_ASSIGNMENT); + + // Second child of assignment is the RHS value + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* value = target->next_sibling; + ASSERT_NE(value, nullptr); + + // Modify the value + const char* new_val = "0"; + value->value_ptr = new_val; + value->value_len = 1; + + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET autocommit = 0"); +} + +// ========== PostgreSQL round-trips ========== + +TEST(PgSQLEmitterTest, SetVarTo) { + // PostgreSQL TO is normalized to = in emitted output + Parser parser; + auto r = parser.parse("SET client_encoding TO 'UTF8'", 29); + ASSERT_NE(r.ast, nullptr); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET client_encoding = 'UTF8'"); +} + +TEST(PgSQLEmitterTest, SelectBasic) { + Parser parser; + auto r = parser.parse("SELECT * FROM users WHERE id = 1", 32); + ASSERT_NE(r.ast, nullptr); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SELECT * FROM users WHERE id = 1"); +} +``` + +- [ ] **Step 2: Create string_builder.h** + +Create `include/sql_parser/string_builder.h`: +```cpp +#ifndef SQL_PARSER_STRING_BUILDER_H +#define SQL_PARSER_STRING_BUILDER_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include <cstring> + +namespace sql_parser { + +// Arena-backed string builder 
for emitting SQL. +// Builds a string by appending chunks. The final result is a contiguous +// StringRef obtained via finish(). All memory is arena-allocated. +class StringBuilder { +public: + explicit StringBuilder(Arena& arena, size_t initial_capacity = 1024) + : arena_(arena), capacity_(initial_capacity), len_(0) { + buf_ = static_cast<char*>(arena_.allocate(capacity_)); + } + + void append(const char* s, size_t n) { + ensure_capacity(n); + if (buf_) { + std::memcpy(buf_ + len_, s, n); + len_ += n; + } + } + + void append(StringRef ref) { + if (ref.ptr && ref.len > 0) { + append(ref.ptr, ref.len); + } + } + + void append(const char* s) { + append(s, std::strlen(s)); + } + + void append_char(char c) { + ensure_capacity(1); + if (buf_) { + buf_[len_++] = c; + } + } + + // Append a space if the last character isn't already a space + void space() { + if (len_ > 0 && buf_[len_ - 1] != ' ') { + append_char(' '); + } + } + + StringRef finish() { + return StringRef{buf_, static_cast<uint32_t>(len_)}; + } + + size_t length() const { return len_; } + +private: + Arena& arena_; + char* buf_; + size_t capacity_; + size_t len_; + + void ensure_capacity(size_t additional) { + if (!buf_) return; + if (len_ + additional <= capacity_) return; + + size_t new_cap = capacity_ * 2; + while (new_cap < len_ + additional) new_cap *= 2; + + char* new_buf = static_cast<char*>(arena_.allocate(new_cap)); + if (new_buf) { + std::memcpy(new_buf, buf_, len_); + } + buf_ = new_buf; + capacity_ = new_cap; + // Old buffer is abandoned in the arena — freed on arena reset + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_STRING_BUILDER_H +``` + +- [ ] **Step 3: Create emitter.h** + +Create `include/sql_parser/emitter.h`: +```cpp +#ifndef SQL_PARSER_EMITTER_H +#define SQL_PARSER_EMITTER_H + +#include "sql_parser/common.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/string_builder.h" + +namespace sql_parser { + +template <typename Dialect> +class Emitter { +public: + explicit Emitter(Arena& 
arena) : sb_(arena) {} + + void emit(const AstNode* node) { + if (!node) return; + emit_node(node); + } + + StringRef result() { return sb_.finish(); } + +private: + StringBuilder sb_; + + void emit_node(const AstNode* node) { + switch (node->type) { + // ---- SET statement ---- + case NodeType::NODE_SET_STMT: emit_set_stmt(node); break; + case NodeType::NODE_SET_NAMES: emit_set_names(node); break; + case NodeType::NODE_SET_CHARSET: emit_set_charset(node); break; + case NodeType::NODE_SET_TRANSACTION: emit_set_transaction(node); break; + case NodeType::NODE_VAR_ASSIGNMENT: emit_var_assignment(node); break; + case NodeType::NODE_VAR_TARGET: emit_var_target(node); break; + + // ---- SELECT statement ---- + case NodeType::NODE_SELECT_STMT: emit_select_stmt(node); break; + case NodeType::NODE_SELECT_OPTIONS: emit_select_options(node); break; + case NodeType::NODE_SELECT_ITEM_LIST:emit_select_item_list(node); break; + case NodeType::NODE_SELECT_ITEM: emit_select_item(node); break; + case NodeType::NODE_FROM_CLAUSE: emit_from_clause(node); break; + case NodeType::NODE_JOIN_CLAUSE: emit_join_clause(node); break; + case NodeType::NODE_WHERE_CLAUSE: emit_where_clause(node); break; + case NodeType::NODE_GROUP_BY_CLAUSE: emit_group_by(node); break; + case NodeType::NODE_HAVING_CLAUSE: emit_having(node); break; + case NodeType::NODE_ORDER_BY_CLAUSE: emit_order_by(node); break; + case NodeType::NODE_ORDER_BY_ITEM: emit_order_by_item(node); break; + case NodeType::NODE_LIMIT_CLAUSE: emit_limit(node); break; + case NodeType::NODE_LOCKING_CLAUSE: emit_locking(node); break; + case NodeType::NODE_INTO_CLAUSE: emit_into(node); break; + + // ---- Table references ---- + case NodeType::NODE_TABLE_REF: emit_table_ref(node); break; + case NodeType::NODE_ALIAS: emit_alias(node); break; + case NodeType::NODE_QUALIFIED_NAME: emit_qualified_name(node); break; + + // ---- Expressions ---- + case NodeType::NODE_BINARY_OP: emit_binary_op(node); break; + case NodeType::NODE_UNARY_OP: 
emit_unary_op(node); break; + case NodeType::NODE_FUNCTION_CALL: emit_function_call(node); break; + case NodeType::NODE_IS_NULL: emit_is_null(node); break; + case NodeType::NODE_IS_NOT_NULL: emit_is_not_null(node); break; + case NodeType::NODE_BETWEEN: emit_between(node); break; + case NodeType::NODE_IN_LIST: emit_in_list(node); break; + case NodeType::NODE_CASE_WHEN: emit_case_when(node); break; + case NodeType::NODE_SUBQUERY: emit_value(node); break; + + // ---- Leaf nodes (emit value directly) ---- + case NodeType::NODE_LITERAL_INT: + case NodeType::NODE_LITERAL_FLOAT: + case NodeType::NODE_LITERAL_NULL: + case NodeType::NODE_COLUMN_REF: + case NodeType::NODE_ASTERISK: + case NodeType::NODE_PLACEHOLDER: + case NodeType::NODE_IDENTIFIER: + emit_value(node); break; + + case NodeType::NODE_LITERAL_STRING: + emit_string_literal(node); break; + + default: + emit_value(node); break; + } + } + + void emit_value(const AstNode* node) { + sb_.append(node->value_ptr, node->value_len); + } + + void emit_string_literal(const AstNode* node) { + sb_.append_char('\''); + sb_.append(node->value_ptr, node->value_len); + sb_.append_char('\''); + } + + // ---- SET ---- + + void emit_set_stmt(const AstNode* node) { + sb_.append("SET "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(child); + } + } + + void emit_set_names(const AstNode* node) { + sb_.append("NAMES "); + const AstNode* charset = node->first_child; + if (charset) emit_node(charset); + const AstNode* collation = charset ? 
charset->next_sibling : nullptr; + if (collation) { + sb_.append(" COLLATE "); + emit_node(collation); + } + } + + void emit_set_charset(const AstNode* node) { + // Detect if original was CHARACTER SET or CHARSET from value + sb_.append("CHARACTER SET "); + if (node->first_child) emit_node(node->first_child); + } + + void emit_set_transaction(const AstNode* node) { + sb_.append("TRANSACTION "); + const AstNode* child = node->first_child; + // First child may be scope (GLOBAL/SESSION) + if (child && child->value_len > 0) { + StringRef val = child->value(); + // Check if this is a scope keyword + if (val.equals_ci("GLOBAL", 6) || val.equals_ci("SESSION", 7) || + val.equals_ci("LOCAL", 5)) { + // This was already emitted before TRANSACTION by the SET stmt emitter + child = child->next_sibling; + } + } + if (child) { + StringRef val = child->value(); + // Check if this is an isolation level or access mode + if (val.equals_ci("READ ONLY", 9) || val.equals_ci("READ WRITE", 10)) { + emit_node(child); + } else { + // It's an isolation level value + sb_.append("ISOLATION LEVEL "); + emit_node(child); + } + } + } + + void emit_var_assignment(const AstNode* node) { + const AstNode* target = node->first_child; + const AstNode* rhs = target ? 
target->next_sibling : nullptr; + + if (target) emit_node(target); + sb_.append(" = "); + if (rhs) emit_node(rhs); + } + + void emit_var_target(const AstNode* node) { + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append_char(' '); + first = false; + emit_node(child); + } + } + + // ---- SELECT ---- + + void emit_select_stmt(const AstNode* node) { + sb_.append("SELECT "); + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + emit_node(child); + } + } + + void emit_select_options(const AstNode* node) { + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + emit_node(child); + sb_.append_char(' '); + } + } + + void emit_select_item_list(const AstNode* node) { + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(child); + } + } + + void emit_select_item(const AstNode* node) { + const AstNode* expr = node->first_child; + if (expr) emit_node(expr); + const AstNode* alias = expr ? 
expr->next_sibling : nullptr; + if (alias && alias->type == NodeType::NODE_ALIAS) { + emit_node(alias); + } + } + + void emit_from_clause(const AstNode* node) { + sb_.append(" FROM "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (child->type == NodeType::NODE_JOIN_CLAUSE) { + sb_.append_char(' '); + emit_node(child); + } else { + if (!first) sb_.append(", "); + first = false; + emit_node(child); + } + } + } + + void emit_table_ref(const AstNode* node) { + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + emit_node(child); + } + } + + void emit_alias(const AstNode* node) { + sb_.append(" AS "); + emit_value(node); + } + + void emit_qualified_name(const AstNode* node) { + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append_char('.'); + first = false; + emit_node(child); + } + } + + void emit_join_clause(const AstNode* node) { + // Join type stored in node value + emit_value(node); + sb_.append_char(' '); + // Children: table_ref, [ON expr | USING (...)] + const AstNode* table = node->first_child; + if (table) { + emit_node(table); + } + const AstNode* condition = table ? 
table->next_sibling : nullptr; + if (condition) { + if (condition->type == NodeType::NODE_IDENTIFIER && + condition->value_len == 5 && + std::memcmp(condition->value_ptr, "USING", 5) == 0) { + sb_.append(" USING ("); + bool first = true; + for (const AstNode* col = condition->first_child; col; col = col->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(col); + } + sb_.append_char(')'); + } else { + sb_.append(" ON "); + emit_node(condition); + } + } + } + + void emit_where_clause(const AstNode* node) { + sb_.append(" WHERE "); + if (node->first_child) emit_node(node->first_child); + } + + void emit_group_by(const AstNode* node) { + sb_.append(" GROUP BY "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(child); + } + } + + void emit_having(const AstNode* node) { + sb_.append(" HAVING "); + if (node->first_child) emit_node(node->first_child); + } + + void emit_order_by(const AstNode* node) { + sb_.append(" ORDER BY "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(child); + } + } + + void emit_order_by_item(const AstNode* node) { + const AstNode* expr = node->first_child; + if (expr) emit_node(expr); + const AstNode* dir = expr ? expr->next_sibling : nullptr; + if (dir) { + sb_.append_char(' '); + emit_node(dir); + } + } + + void emit_limit(const AstNode* node) { + sb_.append(" LIMIT "); + const AstNode* first_val = node->first_child; + if (first_val) emit_node(first_val); + const AstNode* second_val = first_val ? 
first_val->next_sibling : nullptr; + if (second_val) { + sb_.append(" OFFSET "); + emit_node(second_val); + } + } + + void emit_locking(const AstNode* node) { + sb_.append(" FOR "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append_char(' '); + first = false; + emit_node(child); + } + } + + void emit_into(const AstNode* node) { + sb_.append(" INTO "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) sb_.append_char(' '); + first = false; + emit_node(child); + } + } + + // ---- Expressions ---- + + void emit_binary_op(const AstNode* node) { + const AstNode* left = node->first_child; + const AstNode* right = left ? left->next_sibling : nullptr; + if (left) emit_node(left); + sb_.append_char(' '); + emit_value(node); + sb_.append_char(' '); + if (right) emit_node(right); + } + + void emit_unary_op(const AstNode* node) { + emit_value(node); + // Add space for keyword operators like NOT, no space for - or + + if (node->value_len > 1) sb_.append_char(' '); + if (node->first_child) emit_node(node->first_child); + } + + void emit_function_call(const AstNode* node) { + emit_value(node); + sb_.append_char('('); + bool first = true; + for (const AstNode* arg = node->first_child; arg; arg = arg->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(arg); + } + sb_.append_char(')'); + } + + void emit_is_null(const AstNode* node) { + if (node->first_child) emit_node(node->first_child); + sb_.append(" IS NULL"); + } + + void emit_is_not_null(const AstNode* node) { + if (node->first_child) emit_node(node->first_child); + sb_.append(" IS NOT NULL"); + } + + void emit_between(const AstNode* node) { + const AstNode* expr = node->first_child; + const AstNode* low = expr ? expr->next_sibling : nullptr; + const AstNode* high = low ? 
low->next_sibling : nullptr; + if (expr) emit_node(expr); + sb_.append(" BETWEEN "); + if (low) emit_node(low); + sb_.append(" AND "); + if (high) emit_node(high); + } + + void emit_in_list(const AstNode* node) { + const AstNode* expr = node->first_child; + if (expr) emit_node(expr); + sb_.append(" IN ("); + bool first = true; + for (const AstNode* val = expr ? expr->next_sibling : nullptr; val; val = val->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(val); + } + sb_.append_char(')'); + } + + void emit_case_when(const AstNode* node) { + sb_.append("CASE "); + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + emit_node(child); + sb_.append_char(' '); + } + sb_.append("END"); + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_EMITTER_H +``` + +- [ ] **Step 4: Expose arena from Parser and add emit convenience** + +Modify `include/sql_parser/parser.h` — add a public `arena()` accessor: + +Add after `void reset();`: +```cpp + // Access the arena (for emitter use) + Arena& arena() { return arena_; } +``` + +- [ ] **Step 5: Update Makefile.new** + +Add `$(TEST_DIR)/test_emitter.cpp` to `TEST_SRCS`. + +- [ ] **Step 6: Build and run tests** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` + +The round-trip tests verify that `parse(sql) → emit(ast)` produces output identical to the input. If some tests fail due to spacing/formatting differences between original SQL and emitted SQL, adjust the emitter to match. The emitter's job is to produce **semantically equivalent** SQL — exact whitespace preservation is not required, but we aim for faithful reproduction. 
+ +- [ ] **Step 7: Commit** + +```bash +git add include/sql_parser/string_builder.h include/sql_parser/emitter.h \ + include/sql_parser/parser.h tests/test_emitter.cpp Makefile.new +git commit -m "feat: add query emitter with round-trip support for SET and SELECT" +``` + +--- + +### Task 2: Fix Round-Trip Failures and Edge Cases + +This task is for fixing any round-trip test failures from Task 1. The emitter needs to produce output that matches the original input for unmodified ASTs. Common issues: + +- Extra or missing spaces +- Keyword casing (emitter should use same case as original when possible) +- String quoting (emitter must re-add quotes around string literals) +- SET CHARSET vs CHARACTER SET (need to detect which was used) + +- [ ] **Step 1: Run tests and identify failures** + +```bash +make -f Makefile.new clean && make -f Makefile.new all 2>&1 | grep FAILED +``` + +- [ ] **Step 2: Fix each failure** + +For each failing test, trace through the emitter logic and fix the output. Common fixes: +- Adjust spacing in `emit_select_stmt` (no trailing space before FROM) +- Handle string literal quoting in `emit_string_literal` +- Preserve original keyword text from AST node values + +- [ ] **Step 3: Re-run all tests** + +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: All tests pass. + +- [ ] **Step 4: Commit fixes** + +```bash +git add include/sql_parser/emitter.h tests/test_emitter.cpp +git commit -m "fix: correct emitter output for round-trip test compliance" +``` + +--- + +## What's Next + +After this plan is complete: + +1. **Plan 5: Prepared Statement Cache** — Binary protocol support +2. 
**Plan 6: Performance Benchmarks** — Validate latency targets diff --git a/docs/superpowers/plans/2026-03-24-select-parser.md b/docs/superpowers/plans/2026-03-24-select-parser.md new file mode 100644 index 0000000..9ccc913 --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-select-parser.md @@ -0,0 +1,1222 @@ +# SELECT Deep Parser Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build the full SELECT statement deep parser, upgrading SELECT from a Tier 2 stub to a Tier 1 parser that produces a complete AST with all clauses (FROM, JOIN, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, locking). + +**Architecture:** The SELECT parser is a header-only template class `SelectParser` that follows the same pattern as `SetParser`. It uses the existing `ExpressionParser` for all expression positions (select items, WHERE conditions, HAVING, JOIN ON, etc.). Each clause is a separate parse method. The parser is lenient — it produces as much AST as it can, even for partial/malformed queries. + +**Tech Stack:** C++17, existing arena/tokenizer/expression_parser infrastructure + +**Spec:** `docs/superpowers/specs/2026-03-24-sql-parser-design.md` (Tier 1 SELECT Parser section) + +--- + +## Scope + +This plan builds: +1. SELECT parser (`select_parser.h`) — full AST for all SELECT clauses +2. Integration into `Parser` — `parse_select()` upgraded from stub to real parser +3. Comprehensive tests from simple to complex SELECT statements + +**Not in scope:** Query emitter/reconstruction, prepared statement cache, UNION/INTERSECT/EXCEPT. 
+ +--- + +## File Structure + +``` +include/sql_parser/ + select_parser.h — SELECT statement parser (header-only template) + +src/sql_parser/ + parser.cpp — (modify) Replace parse_select() stub, add #include + +tests/ + test_select.cpp — SELECT parser tests + +Makefile.new — (modify) Add test_select.cpp to TEST_SRCS +``` + +--- + +### Task 1: SELECT Parser — Basic SELECT and FROM + +**Files:** +- Create: `include/sql_parser/select_parser.h` +- Create: `tests/test_select.cpp` +- Modify: `Makefile.new` — add test_select.cpp +- Modify: `src/sql_parser/parser.cpp` — replace parse_select() stub + +This task implements the core SELECT structure: select options, select item list (with aliases), and FROM clause with simple table references. + +- [ ] **Step 1: Write tests for basic SELECT** + +Create `tests/test_select.cpp`: +```cpp +#include <gtest/gtest.h> +#include "sql_parser/parser.h" + +using namespace sql_parser; + +class MySQLSelectTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } +}; + +// ========== Basic SELECT ========== + +TEST_F(MySQLSelectTest, SelectLiteral) { + auto r = parser.parse("SELECT 1", 8); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_SELECT_STMT); +} + +TEST_F(MySQLSelectTest, SelectStar) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* items = find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST); + ASSERT_NE(items, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); +} + 
+TEST_F(MySQLSelectTest, SelectColumns) { + auto r = parser.parse("SELECT id, name, email FROM users", 33); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* items = find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST); + ASSERT_NE(items, nullptr); + EXPECT_EQ(child_count(items), 3); +} + +TEST_F(MySQLSelectTest, SelectWithAlias) { + auto r = parser.parse("SELECT id AS user_id, name AS user_name FROM users", 50); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* items = find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST); + ASSERT_NE(items, nullptr); + // Each item should have an alias child + auto* first_item = items->first_child; + ASSERT_NE(first_item, nullptr); + EXPECT_EQ(first_item->type, NodeType::NODE_SELECT_ITEM); + auto* alias = find_child(first_item, NodeType::NODE_ALIAS); + ASSERT_NE(alias, nullptr); +} + +TEST_F(MySQLSelectTest, SelectImplicitAlias) { + // Alias without AS keyword + auto r = parser.parse("SELECT id user_id FROM users", 28); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectDistinct) { + auto r = parser.parse("SELECT DISTINCT name FROM users", 31); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_SELECT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLSelectTest, SelectSqlCalcFoundRows) { + auto r = parser.parse("SELECT SQL_CALC_FOUND_ROWS * FROM users", 40); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromQualifiedTable) { + auto r = parser.parse("SELECT * FROM mydb.users", 24); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromTableAlias) { + auto r = parser.parse("SELECT u.id FROM users u", 24); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromTableAsAlias) { + auto r = 
parser.parse("SELECT u.id FROM users AS u", 27); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromMultipleTables) { + auto r = parser.parse("SELECT * FROM users, orders", 27); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); + EXPECT_GE(child_count(from), 2); +} + +TEST_F(MySQLSelectTest, SelectExpression) { + auto r = parser.parse("SELECT 1 + 2, 'hello', NOW()", 28); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectNoFrom) { + auto r = parser.parse("SELECT 1, 'a', NOW()", 20); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + // No FROM clause + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + EXPECT_EQ(from, nullptr); +} +``` + +- [ ] **Step 2: Add TK_UNION, TK_OF, and TK_EXISTS to token.h and keyword tables** + +These tokens are referenced by `select_parser.h` and `expression_parser.h`, so they must exist before those files compile. + +Add `TK_UNION`, `TK_OF`, and `TK_EXISTS` to `TokenType` enum in `include/sql_parser/token.h` (after `TK_RESET`): +```cpp +TK_UNION, +TK_OF, +TK_EXISTS, +``` + +Note: `TK_EXISTS` is already in the enum from Plan 1. Verify it exists; if not, add it. + +Add to `include/sql_parser/keywords_mysql.h` sorted array: +- `{"OF", 2, TokenType::TK_OF},` between `NULL` and `OFFSET` +- `{"UNION", 5, TokenType::TK_UNION},` between `TRUNCATE` and `UNCOMMITTED` + +Add same entries to `include/sql_parser/keywords_pgsql.h` at the same sorted positions. + +Also add `EXISTS` handling to the expression parser. 
 In `include/sql_parser/expression_parser.h`, add a case in `parse_atom()` before the `TK_LPAREN` case: +```cpp + case TokenType::TK_EXISTS: { + tok_.skip(); + // EXISTS (subquery) + AstNode* node = make_node(arena_, NodeType::NODE_SUBQUERY); + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); + skip_to_matching_paren(); + } + return node; + } +``` + +- [ ] **Step 3: Create select_parser.h** + +Create `include/sql_parser/select_parser.h`: +```cpp +#ifndef SQL_PARSER_SELECT_PARSER_H +#define SQL_PARSER_SELECT_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +template <Dialect D> +class SelectParser { +public: + SelectParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {} + + // Parse a SELECT statement (SELECT keyword already consumed by classifier). 
+ AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_SELECT_STMT); + if (!root) return nullptr; + + // SELECT options: DISTINCT, ALL, SQL_CALC_FOUND_ROWS + AstNode* opts = parse_select_options(); + if (opts) root->add_child(opts); + + // Select item list + AstNode* items = parse_select_item_list(); + if (items) root->add_child(items); + + // INTO (before FROM in some MySQL variants — skip for now, handle after FROM) + + // FROM clause + if (tok_.peek().type == TokenType::TK_FROM) { + tok_.skip(); + AstNode* from = parse_from_clause(); + if (from) root->add_child(from); + } + + // WHERE clause + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + + // GROUP BY clause + if (tok_.peek().type == TokenType::TK_GROUP) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + AstNode* group_by = parse_group_by(); + if (group_by) root->add_child(group_by); + } + + // HAVING clause + if (tok_.peek().type == TokenType::TK_HAVING) { + tok_.skip(); + AstNode* having = parse_having(); + if (having) root->add_child(having); + } + + // ORDER BY clause + if (tok_.peek().type == TokenType::TK_ORDER) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + AstNode* order_by = parse_order_by(); + if (order_by) root->add_child(order_by); + } + + // LIMIT clause + if (tok_.peek().type == TokenType::TK_LIMIT) { + tok_.skip(); + AstNode* limit = parse_limit(); + if (limit) root->add_child(limit); + } + + // FOR UPDATE / FOR SHARE (locking) + if (tok_.peek().type == TokenType::TK_FOR) { + AstNode* lock = parse_locking(); + if (lock) root->add_child(lock); + } + + // INTO (MySQL: can appear here too — INTO OUTFILE/DUMPFILE/var) + if constexpr (D == Dialect::MySQL) { + if (tok_.peek().type == TokenType::TK_INTO) { + AstNode* into = parse_into(); + if (into) root->add_child(into); + } + } + + return root; + } + +private: + Tokenizer& tok_; + 
Arena& arena_; + ExpressionParser expr_parser_; + + // ---- SELECT options ---- + + AstNode* parse_select_options() { + AstNode* opts = nullptr; + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_DISTINCT || t.type == TokenType::TK_ALL) { + if (!opts) opts = make_node(arena_, NodeType::NODE_SELECT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else if (t.type == TokenType::TK_SQL_CALC_FOUND_ROWS) { + if (!opts) opts = make_node(arena_, NodeType::NODE_SELECT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else { + break; + } + } + return opts; + } + + // ---- Select item list ---- + + AstNode* parse_select_item_list() { + AstNode* list = make_node(arena_, NodeType::NODE_SELECT_ITEM_LIST); + if (!list) return nullptr; + + while (true) { + AstNode* item = parse_select_item(); + if (!item) break; + list->add_child(item); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return list; + } + + AstNode* parse_select_item() { + AstNode* item = make_node(arena_, NodeType::NODE_SELECT_ITEM); + if (!item) return nullptr; + + AstNode* expr = expr_parser_.parse(); + if (!expr) return nullptr; + item->add_child(expr); + + // Optional alias: AS name, or just name (implicit alias) + Token next = tok_.peek(); + if (next.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + AstNode* alias = make_node(arena_, NodeType::NODE_ALIAS, alias_name.text); + item->add_child(alias); + } else if (is_alias_start(next.type)) { + // Implicit alias (no AS keyword): SELECT expr alias_name + tok_.skip(); + AstNode* alias = make_node(arena_, NodeType::NODE_ALIAS, next.text); + item->add_child(alias); + } + return item; + } + + // ---- FROM clause ---- + + AstNode* parse_from_clause() { + AstNode* from = make_node(arena_, NodeType::NODE_FROM_CLAUSE); + if (!from) return nullptr; + + // First table 
reference + AstNode* table_ref = parse_table_reference(); + if (table_ref) from->add_child(table_ref); + + // Additional table refs (comma join) or explicit JOINs + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_COMMA) { + // Comma join: FROM t1, t2 + tok_.skip(); + AstNode* next_ref = parse_table_reference(); + if (next_ref) from->add_child(next_ref); + } else if (is_join_start(t.type)) { + // Explicit JOIN + AstNode* join = parse_join(from->first_child); + if (join) { + // Replace the last table ref with the join node + // Actually, append the join as a child of FROM + from->add_child(join); + } + } else { + break; + } + } + + return from; + } + + AstNode* parse_table_reference() { + Token t = tok_.peek(); + + // Subquery: (SELECT ...) + if (t.type == TokenType::TK_LPAREN) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_SELECT) { + AstNode* subq = make_node(arena_, NodeType::NODE_SUBQUERY); + // Skip to matching paren + int depth = 1; + while (depth > 0) { + Token st = tok_.next_token(); + if (st.type == TokenType::TK_LPAREN) ++depth; + else if (st.type == TokenType::TK_RPAREN) --depth; + else if (st.type == TokenType::TK_EOF) break; + } + // Optional alias + AstNode* ref = make_node(arena_, NodeType::NODE_TABLE_REF); + ref->add_child(subq); + parse_optional_alias(ref); + return ref; + } + // Parenthesized table reference — parse inner + AstNode* inner = parse_table_reference(); + if (tok_.peek().type == TokenType::TK_RPAREN) tok_.skip(); + return inner; + } + + // Simple table name or schema.table + AstNode* ref = make_node(arena_, NodeType::NODE_TABLE_REF); + Token name = tok_.next_token(); + + if (tok_.peek().type == TokenType::TK_DOT) { + // Qualified: schema.table + tok_.skip(); + Token table_name = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, 
 table_name.text)); + ref->add_child(qname); + } else { + ref->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + } + + // Optional alias + parse_optional_alias(ref); + return ref; + } + + void parse_optional_alias(AstNode* parent) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + parent->add_child(make_node(arena_, NodeType::NODE_ALIAS, alias_name.text)); + } else if (is_alias_start(t.type)) { + tok_.skip(); + parent->add_child(make_node(arena_, NodeType::NODE_ALIAS, t.text)); + } + } + + // ---- JOIN ---- + + AstNode* parse_join(AstNode* /* left_ref */) { + AstNode* join = make_node(arena_, NodeType::NODE_JOIN_CLAUSE); + if (!join) return nullptr; + + // Consume join type tokens + Token t = tok_.peek(); + StringRef join_type_start = t.text; + StringRef join_type_end = t.text; + + // Optional: NATURAL, LEFT, RIGHT, FULL, INNER, OUTER, CROSS + while (t.type == TokenType::TK_NATURAL || t.type == TokenType::TK_LEFT || + t.type == TokenType::TK_RIGHT || t.type == TokenType::TK_FULL || + t.type == TokenType::TK_INNER || t.type == TokenType::TK_OUTER || + t.type == TokenType::TK_CROSS) { + tok_.skip(); + join_type_end = t.text; + t = tok_.peek(); + } + + // Expect JOIN keyword + if (t.type == TokenType::TK_JOIN) { + join_type_end = t.text; + tok_.skip(); + } + + // Set join type as value (covers the span from first modifier to JOIN) + StringRef join_type{join_type_start.ptr, + static_cast<size_t>((join_type_end.ptr + join_type_end.len) - join_type_start.ptr)}; + join->value_ptr = join_type.ptr; + join->value_len = join_type.len; + + // Right table reference + AstNode* right_ref = parse_table_reference(); + if (right_ref) join->add_child(right_ref); + + // Join condition: ON expr or USING (col_list) + if (tok_.peek().type == TokenType::TK_ON) { + tok_.skip(); + AstNode* on_expr = expr_parser_.parse(); + if (on_expr) join->add_child(on_expr); + } else if (tok_.peek().type == 
TokenType::TK_USING) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); + AstNode* using_list = make_node(arena_, NodeType::NODE_IDENTIFIER, StringRef{"USING", 5}); + while (true) { + Token col = tok_.next_token(); + using_list->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) tok_.skip(); + join->add_child(using_list); + } + } + + return join; + } + + // ---- WHERE ---- + + AstNode* parse_where_clause() { + AstNode* where = make_node(arena_, NodeType::NODE_WHERE_CLAUSE); + if (!where) return nullptr; + AstNode* expr = expr_parser_.parse(); + if (expr) where->add_child(expr); + return where; + } + + // ---- GROUP BY ---- + + AstNode* parse_group_by() { + AstNode* group_by = make_node(arena_, NodeType::NODE_GROUP_BY_CLAUSE); + if (!group_by) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + group_by->add_child(expr); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return group_by; + } + + // ---- HAVING ---- + + AstNode* parse_having() { + AstNode* having = make_node(arena_, NodeType::NODE_HAVING_CLAUSE); + if (!having) return nullptr; + AstNode* expr = expr_parser_.parse(); + if (expr) having->add_child(expr); + return having; + } + + // ---- ORDER BY ---- + + AstNode* parse_order_by() { + AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE); + if (!order_by) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + + AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM); + item->add_child(expr); + + // Optional ASC/DESC + Token dir = tok_.peek(); + if (dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) { + tok_.skip(); + item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text)); + } + + 
order_by->add_child(item); + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return order_by; + } + + // ---- LIMIT ---- + + AstNode* parse_limit() { + AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE); + if (!limit) return nullptr; + + // LIMIT count [OFFSET offset] or LIMIT offset, count (MySQL) + AstNode* first = expr_parser_.parse(); + if (first) limit->add_child(first); + + if (tok_.peek().type == TokenType::TK_OFFSET) { + tok_.skip(); + AstNode* offset = expr_parser_.parse(); + if (offset) limit->add_child(offset); + } else if (tok_.peek().type == TokenType::TK_COMMA) { + // MySQL: LIMIT offset, count + tok_.skip(); + AstNode* count = expr_parser_.parse(); + if (count) limit->add_child(count); + } + + if constexpr (D == Dialect::PostgreSQL) { + // PostgreSQL also supports FETCH FIRST N ROWS ONLY after LIMIT/OFFSET + // We handle OFFSET here too since PgSQL uses LIMIT x OFFSET y + } + + return limit; + } + + // ---- FOR UPDATE / FOR SHARE ---- + + AstNode* parse_locking() { + AstNode* lock = make_node(arena_, NodeType::NODE_LOCKING_CLAUSE); + if (!lock) return nullptr; + + tok_.skip(); // consume FOR + Token strength = tok_.next_token(); // UPDATE or SHARE + lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, strength.text)); + + // Optional: OF table_list + if (tok_.peek().type == TokenType::TK_OF) { + tok_.skip(); + while (true) { + Token table = tok_.next_token(); + lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, table.text)); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + } + + // Optional: NOWAIT or SKIP LOCKED + if (tok_.peek().type == TokenType::TK_NOWAIT) { + tok_.skip(); + lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, StringRef{"NOWAIT", 6})); + } else if (tok_.peek().type == TokenType::TK_SKIP) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_LOCKED) tok_.skip(); + lock->add_child(make_node(arena_, 
NodeType::NODE_IDENTIFIER, StringRef{"SKIP LOCKED", 11})); + } + + return lock; + } + + // ---- INTO (MySQL: INTO OUTFILE/DUMPFILE/@var) ---- + + AstNode* parse_into() { + AstNode* into = make_node(arena_, NodeType::NODE_INTO_CLAUSE); + if (!into) return nullptr; + + tok_.skip(); // consume INTO + Token t = tok_.peek(); + + if (t.type == TokenType::TK_OUTFILE) { + tok_.skip(); + Token filename = tok_.next_token(); + into->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, + StringRef{"OUTFILE", 7})); + into->add_child(make_node(arena_, NodeType::NODE_LITERAL_STRING, filename.text)); + } else if (t.type == TokenType::TK_DUMPFILE) { + tok_.skip(); + Token filename = tok_.next_token(); + into->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, + StringRef{"DUMPFILE", 8})); + into->add_child(make_node(arena_, NodeType::NODE_LITERAL_STRING, filename.text)); + } else { + // INTO @var1, @var2, ... + while (true) { + AstNode* var = expr_parser_.parse(); + if (var) into->add_child(var); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + } + + return into; + } + + // ---- Helpers ---- + + static bool is_join_start(TokenType type) { + return type == TokenType::TK_JOIN || type == TokenType::TK_INNER || + type == TokenType::TK_LEFT || type == TokenType::TK_RIGHT || + type == TokenType::TK_FULL || type == TokenType::TK_OUTER || + type == TokenType::TK_CROSS || type == TokenType::TK_NATURAL; + } + + // Check if a token can start an implicit alias (identifier-like, not a clause keyword) + static bool is_alias_start(TokenType type) { + if (type == TokenType::TK_IDENTIFIER) return true; + // Some keywords are NOT valid alias starts because they start clauses + switch (type) { + case TokenType::TK_FROM: + case TokenType::TK_WHERE: + case TokenType::TK_GROUP: + case TokenType::TK_HAVING: + case TokenType::TK_ORDER: + case TokenType::TK_LIMIT: + case TokenType::TK_FOR: + case TokenType::TK_INTO: + case TokenType::TK_JOIN: + case 
 TokenType::TK_INNER: + case TokenType::TK_LEFT: + case TokenType::TK_RIGHT: + case TokenType::TK_FULL: + case TokenType::TK_OUTER: + case TokenType::TK_CROSS: + case TokenType::TK_NATURAL: + case TokenType::TK_ON: + case TokenType::TK_USING: + case TokenType::TK_UNION: + case TokenType::TK_SEMICOLON: + case TokenType::TK_RPAREN: + case TokenType::TK_EOF: + case TokenType::TK_COMMA: + case TokenType::TK_SET: + case TokenType::TK_LOCK: + case TokenType::TK_UNLOCK: + return false; + default: + return true; // Keywords not in the blocklist can be implicit aliases + } + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_SELECT_PARSER_H +``` + +- [ ] **Step 4: Integrate into Parser class** + +Modify `src/sql_parser/parser.cpp` — add include and replace stub: + +Add after existing includes: +```cpp +#include "sql_parser/select_parser.h" +``` + +Replace `parse_select()`: +```cpp +template <Dialect D> +ParseResult Parser<D>::parse_select() { + ParseResult r; + r.stmt_type = StmtType::SELECT; + + SelectParser<D> select_parser(tokenizer_, arena_); + AstNode* ast = select_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} +``` + +- [ ] **Step 5: Update Makefile.new** + +Add `$(TEST_DIR)/test_select.cpp` to `TEST_SRCS` in `Makefile.new`. + +- [ ] **Step 6: Build and run tests** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: All tests pass. 
+ +- [ ] **Step 7: Commit** + +```bash +git add include/sql_parser/select_parser.h include/sql_parser/token.h \ + include/sql_parser/keywords_mysql.h include/sql_parser/keywords_pgsql.h \ + src/sql_parser/parser.cpp tests/test_select.cpp Makefile.new +git commit -m "feat: add SELECT deep parser with FROM, WHERE, GROUP BY, ORDER BY, LIMIT, JOIN" +``` + +--- + +### Task 2: Comprehensive SELECT Tests — JOINs, Subqueries, Complex Queries + +**Files:** +- Modify: `tests/test_select.cpp` — add extensive tests + +- [ ] **Step 1: Add JOIN tests** + +Append to `tests/test_select.cpp`: +```cpp +// ========== JOINs ========== + +TEST_F(MySQLSelectTest, InnerJoin) { + auto r = parser.parse("SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id", 66); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); +} + +TEST_F(MySQLSelectTest, LeftJoin) { + auto r = parser.parse("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id", 65); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, RightJoin) { + auto r = parser.parse("SELECT * FROM users RIGHT JOIN orders ON users.id = orders.user_id", 66); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, LeftOuterJoin) { + auto r = parser.parse("SELECT * FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id", 71); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, CrossJoin) { + auto r = parser.parse("SELECT * FROM users CROSS JOIN orders", 37); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, NaturalJoin) { + auto r = parser.parse("SELECT * FROM users NATURAL JOIN orders", 39); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, JoinUsing) { + auto r = parser.parse("SELECT * FROM users JOIN orders USING (user_id)", 48); + EXPECT_EQ(r.status, ParseResult::OK); +} + 
+TEST_F(MySQLSelectTest, MultipleJoins) { + const char* sql = "SELECT * FROM users JOIN orders ON users.id = orders.user_id JOIN items ON orders.id = items.order_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, JoinWithAlias) { + auto r = parser.parse("SELECT * FROM users u JOIN orders o ON u.id = o.user_id", 55); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== WHERE ========== + +TEST_F(MySQLSelectTest, WhereSimple) { + auto r = parser.parse("SELECT * FROM users WHERE id = 1", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* where = find_child(r.ast, NodeType::NODE_WHERE_CLAUSE); + ASSERT_NE(where, nullptr); +} + +TEST_F(MySQLSelectTest, WhereComplex) { + auto r = parser.parse("SELECT * FROM users WHERE age > 18 AND status = 'active'", 56); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereIn) { + auto r = parser.parse("SELECT * FROM users WHERE id IN (1, 2, 3)", 42); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereBetween) { + auto r = parser.parse("SELECT * FROM users WHERE age BETWEEN 18 AND 65", 48); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereLike) { + auto r = parser.parse("SELECT * FROM users WHERE name LIKE '%john%'", 44); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereIsNull) { + auto r = parser.parse("SELECT * FROM users WHERE email IS NULL", 39); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereSubquery) { + auto r = parser.parse("SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)", 60); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== GROUP BY / HAVING ========== + +TEST_F(MySQLSelectTest, GroupBy) { + auto r = parser.parse("SELECT status, COUNT(*) FROM users GROUP BY status", 51); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* gb = find_child(r.ast, 
NodeType::NODE_GROUP_BY_CLAUSE); + ASSERT_NE(gb, nullptr); +} + +TEST_F(MySQLSelectTest, GroupByMultiple) { + auto r = parser.parse("SELECT dept, status, COUNT(*) FROM users GROUP BY dept, status", 62); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, GroupByHaving) { + auto r = parser.parse("SELECT status, COUNT(*) FROM users GROUP BY status HAVING COUNT(*) > 5", 71); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* having = find_child(r.ast, NodeType::NODE_HAVING_CLAUSE); + ASSERT_NE(having, nullptr); +} + +// ========== ORDER BY ========== + +TEST_F(MySQLSelectTest, OrderBy) { + auto r = parser.parse("SELECT * FROM users ORDER BY name", 33); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ob = find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE); + ASSERT_NE(ob, nullptr); +} + +TEST_F(MySQLSelectTest, OrderByDesc) { + auto r = parser.parse("SELECT * FROM users ORDER BY created_at DESC", 45); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, OrderByMultiple) { + auto r = parser.parse("SELECT * FROM users ORDER BY last_name ASC, first_name ASC", 58); + EXPECT_EQ(r.status, ParseResult::OK); + auto* ob = find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE); + ASSERT_NE(ob, nullptr); + EXPECT_EQ(child_count(ob), 2); +} + +// ========== LIMIT ========== + +TEST_F(MySQLSelectTest, Limit) { + auto r = parser.parse("SELECT * FROM users LIMIT 10", 28); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* limit = find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE); + ASSERT_NE(limit, nullptr); +} + +TEST_F(MySQLSelectTest, LimitOffset) { + auto r = parser.parse("SELECT * FROM users LIMIT 10 OFFSET 20", 38); + EXPECT_EQ(r.status, ParseResult::OK); + auto* limit = find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE); + ASSERT_NE(limit, nullptr); + EXPECT_EQ(child_count(limit), 2); +} + +TEST_F(MySQLSelectTest, LimitCommaOffset) { + // MySQL syntax: LIMIT offset, count + 
auto r = parser.parse("SELECT * FROM users LIMIT 20, 10", 32);
+  EXPECT_EQ(r.status, ParseResult::OK);
+  auto* limit = find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE);
+  ASSERT_NE(limit, nullptr);
+  EXPECT_EQ(child_count(limit), 2);
+}
+
+// ========== FOR UPDATE / FOR SHARE ==========
+
+TEST_F(MySQLSelectTest, ForUpdate) {
+  auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR UPDATE", 43);
+  EXPECT_EQ(r.status, ParseResult::OK);
+  auto* lock = find_child(r.ast, NodeType::NODE_LOCKING_CLAUSE);
+  ASSERT_NE(lock, nullptr);
+}
+
+TEST_F(MySQLSelectTest, ForShare) {
+  auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR SHARE", 42);
+  EXPECT_EQ(r.status, ParseResult::OK);
+}
+
+TEST_F(MySQLSelectTest, ForUpdateNowait) {
+  auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR UPDATE NOWAIT", 50);
+  EXPECT_EQ(r.status, ParseResult::OK);
+}
+
+TEST_F(MySQLSelectTest, ForUpdateSkipLocked) {
+  auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR UPDATE SKIP LOCKED", 55);
+  EXPECT_EQ(r.status, ParseResult::OK);
+}
+
+// ========== Complex queries ==========
+
+TEST_F(MySQLSelectTest, FullQuery) {
+  const char* sql = "SELECT u.id, u.name, COUNT(o.id) AS order_count "
+                    "FROM users u "
+                    "LEFT JOIN orders o ON u.id = o.user_id "
+                    "WHERE u.status = 'active' "
+                    "GROUP BY u.id, u.name "
+                    "HAVING COUNT(o.id) > 5 "
+                    "ORDER BY order_count DESC "
+                    "LIMIT 10";
+  auto r = parser.parse(sql, strlen(sql));
+  EXPECT_EQ(r.status, ParseResult::OK);
+  ASSERT_NE(r.ast, nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST), nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_FROM_CLAUSE), nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_WHERE_CLAUSE), nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_GROUP_BY_CLAUSE), nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_HAVING_CLAUSE), nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr);
+  EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), 
nullptr); +} + +TEST_F(MySQLSelectTest, SubqueryInFrom) { + const char* sql = "SELECT t.id FROM (SELECT id FROM users WHERE active = 1) AS t"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, MultiStatement) { + const char* sql = "SELECT 1; SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + EXPECT_TRUE(r.has_remaining()); +} + +TEST_F(MySQLSelectTest, SelectWithSemicolon) { + auto r = parser.parse("SELECT * FROM users;", 20); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== Bulk data-driven tests ========== + +struct SelectTestCase { + const char* sql; + const char* description; +}; + +static const SelectTestCase select_bulk_cases[] = { + {"SELECT 1", "literal"}, + {"SELECT 1, 2, 3", "multiple literals"}, + {"SELECT 'hello'", "string literal"}, + {"SELECT NULL", "null"}, + {"SELECT TRUE", "true"}, + {"SELECT FALSE", "false"}, + {"SELECT NOW()", "function call"}, + {"SELECT 1 + 2", "arithmetic"}, + {"SELECT *", "star"}, + {"SELECT * FROM t", "star from table"}, + {"SELECT a FROM t", "single column"}, + {"SELECT a, b, c FROM t", "multiple columns"}, + {"SELECT a AS x FROM t", "alias with AS"}, + {"SELECT t.a FROM t", "qualified column"}, + {"SELECT t.* FROM t", "qualified star"}, + {"SELECT DISTINCT a FROM t", "distinct"}, + {"SELECT ALL a FROM t", "all"}, + {"SELECT SQL_CALC_FOUND_ROWS * FROM t", "sql_calc_found_rows"}, + {"SELECT * FROM db.t", "qualified table"}, + {"SELECT * FROM t AS alias", "table alias with AS"}, + {"SELECT * FROM t alias", "table alias implicit"}, + {"SELECT * FROM t1, t2", "comma join"}, + {"SELECT * FROM t1 JOIN t2 ON t1.id = t2.id", "inner join"}, + {"SELECT * FROM t1 LEFT JOIN t2 ON t1.id = t2.id", "left join"}, + {"SELECT * FROM t1 RIGHT JOIN t2 ON t1.id = t2.id", "right join"}, + {"SELECT * FROM t1 CROSS JOIN t2", "cross join"}, + {"SELECT * FROM t1 NATURAL JOIN t2", "natural 
join"}, + {"SELECT * FROM t1 JOIN t2 USING (id)", "join using"}, + {"SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.id = t2.id", "left outer join"}, + {"SELECT * FROM t WHERE a = 1", "where equal"}, + {"SELECT * FROM t WHERE a > 1 AND b < 10", "where and"}, + {"SELECT * FROM t WHERE a IN (1,2,3)", "where in"}, + {"SELECT * FROM t WHERE a IS NULL", "where is null"}, + {"SELECT * FROM t WHERE a IS NOT NULL", "where is not null"}, + {"SELECT * FROM t WHERE a BETWEEN 1 AND 10", "where between"}, + {"SELECT * FROM t WHERE a LIKE '%x%'", "where like"}, + {"SELECT * FROM t WHERE a NOT IN (1,2)", "where not in"}, + {"SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t2)", "where exists"}, + {"SELECT a, COUNT(*) FROM t GROUP BY a", "group by"}, + {"SELECT a, b, COUNT(*) FROM t GROUP BY a, b", "group by multiple"}, + {"SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1", "having"}, + {"SELECT * FROM t ORDER BY a", "order by"}, + {"SELECT * FROM t ORDER BY a DESC", "order by desc"}, + {"SELECT * FROM t ORDER BY a ASC, b DESC", "order by multiple"}, + {"SELECT * FROM t LIMIT 10", "limit"}, + {"SELECT * FROM t LIMIT 10 OFFSET 5", "limit offset"}, + {"SELECT * FROM t LIMIT 5, 10", "limit comma"}, + {"SELECT * FROM t WHERE a = 1 FOR UPDATE", "for update"}, + {"SELECT * FROM t WHERE a = 1 FOR SHARE", "for share"}, + {"SELECT * FROM t FOR UPDATE NOWAIT", "for update nowait"}, + {"SELECT * FROM t FOR UPDATE SKIP LOCKED", "for update skip locked"}, + {"SELECT COUNT(*), SUM(a), AVG(b), MIN(c), MAX(d) FROM t", "aggregate functions"}, + {"SELECT CASE WHEN a = 1 THEN 'x' ELSE 'y' END FROM t", "case when"}, + {"SELECT * FROM (SELECT 1) AS t", "subquery in from"}, + {"SELECT * FROM t1 JOIN t2 ON t1.a = t2.a JOIN t3 ON t2.b = t3.b", "multiple joins"}, + {"SELECT a FROM t WHERE b = (SELECT MAX(b) FROM t2)", "scalar subquery in where"}, +}; + +TEST(MySQLSelectBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : select_bulk_cases) { + auto r = parser.parse(tc.sql, 
strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::SELECT) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== PostgreSQL SELECT ========== + +class PgSQLSelectTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(PgSQLSelectTest, BasicSelect) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSelectTest, LimitOffset) { + auto r = parser.parse("SELECT * FROM users LIMIT 10 OFFSET 5", 37); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(PgSQLSelectTest, ForUpdate) { + auto r = parser.parse("SELECT * FROM users FOR UPDATE", 30); + EXPECT_EQ(r.status, ParseResult::OK); +} +``` + +- [ ] **Step 2: Build and run all tests** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: All tests pass. + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_select.cpp +git commit -m "test: add comprehensive SELECT parser tests with JOINs, subqueries, and bulk cases" +``` + +--- + +### Task 3: Verify and Clean Up + +**Files:** +- No new files — verification only + +- [ ] **Step 1: Run full test suite** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: ALL tests pass, zero warnings. + +- [ ] **Step 2: Check for compiler warnings** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all 2>&1 | grep -i warning +``` +Expected: Zero warnings (or only from Google Test internals). + +- [ ] **Step 3: Commit if any fixes were needed** + +```bash +# Only if changes were made +git add -A && git commit -m "fix: clean up warnings after SELECT parser integration" +``` + +--- + +## What's Next + +After this plan is complete: + +1. 
**Plan 4: Query Emitter** — AST → SQL reconstruction (parse → modify → emit) +2. **Plan 5: Prepared Statement Cache** — Binary protocol support +3. **Plan 6: Performance Benchmarks** — Validate latency targets diff --git a/docs/superpowers/plans/2026-03-24-sql-parser-foundation.md b/docs/superpowers/plans/2026-03-24-sql-parser-foundation.md new file mode 100644 index 0000000..165158c --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-sql-parser-foundation.md @@ -0,0 +1,2664 @@ +# SQL Parser Foundation Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Build the foundational parser pipeline — core types, arena allocator, tokenizer, classifier, and basic Tier 2 extractors — so that any SQL statement can be classified and key metadata extracted. + +**Architecture:** Three-layer pipeline (Tokenizer → Classifier → Extractors) with compile-time dialect templating (`Dialect::MySQL`, `Dialect::PostgreSQL`). Arena allocator for zero-copy, sub-microsecond operation. This plan covers Layers 1-3 from the spec but defers Tier 1 deep parsers (SELECT, SET), the emitter, and prepared statement cache to subsequent plans. + +**Tech Stack:** C++17, GNU Make, Google Test (header-only download for tests), Google Benchmark (for perf tests, deferred to later plan) + +**Spec:** `docs/superpowers/specs/2026-03-24-sql-parser-design.md` + +--- + +## Scope + +This plan builds: +1. Build system (new Makefile for the new parser + tests) +2. Core types (`StringRef`, `Dialect` enum, `NodeType`, `StmtType`, `TokenType`) +3. Arena allocator (block-chained, reset, max size) +4. `AstNode` (32-byte, arena-allocated, intrusive linked list) +5. `ParseResult` and `ErrorInfo` +6. Tokenizer (dialect-templated, MySQL + PostgreSQL, keyword perfect hash) +7. 
Classifier (switch dispatch on first token) +8. Tier 2 extractors (extract table name / schema for DML + DDL, transaction type, USE database) + +**Not in scope for this plan:** Tier 1 deep parsers (SELECT, SET), expression parser, emitter/reconstruction, prepared statement cache, benchmarks. + +--- + +## File Structure + +``` +include/sql_parser/ + common.h — StringRef, Dialect enum, StmtType, NodeType enums + arena.h — Arena class (block-chained allocator) + token.h — Token struct, TokenType enum + ast.h — AstNode struct + parse_result.h — ParseResult, ErrorInfo + tokenizer.h — Tokenizer template (declaration + inline impl) + keywords_mysql.h — MySQL keyword lookup table + keywords_pgsql.h — PostgreSQL keyword lookup table + parser.h — Parser public API + +src/sql_parser/ + arena.cpp — Arena non-inline methods + parser.cpp — Parser classifier + Tier 2 extractors + explicit instantiations + +tests/ + test_main.cpp — Google Test main() + test_arena.cpp — Arena unit tests + test_tokenizer.cpp — Tokenizer unit tests (MySQL + PostgreSQL) + test_classifier.cpp — Classifier + Tier 2 extractor tests + +Makefile.new — New build system (renamed to Makefile after old parser removal) +``` + +--- + +### Task 1: Build System Setup + +**Files:** +- Create: `Makefile.new` +- Create: `tests/test_main.cpp` + +This task sets up the new Makefile targeting the `include/sql_parser/` and `src/sql_parser/` layout, with a test target using Google Test. The old Makefile is left untouched. 
+ +- [ ] **Step 1: Download Google Test header-only** + +Run: +```bash +mkdir -p third_party/googletest +git clone --depth 1 --branch v1.14.0 https://github.com/google/googletest.git third_party/googletest +``` + +- [ ] **Step 2: Create test_main.cpp** + +Create `tests/test_main.cpp`: +```cpp +#include + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +``` + +- [ ] **Step 3: Create Makefile.new** + +Create `Makefile.new`: +```makefile +CXX = g++ +CXXFLAGS = -std=c++17 -Wall -Wextra -g -O2 +CPPFLAGS = -I./include -I./third_party/googletest/googletest/include + +PROJECT_ROOT = . +SRC_DIR = $(PROJECT_ROOT)/src/sql_parser +INCLUDE_DIR = $(PROJECT_ROOT)/include/sql_parser +TEST_DIR = $(PROJECT_ROOT)/tests + +# Library sources +LIB_SRCS = $(SRC_DIR)/arena.cpp $(SRC_DIR)/parser.cpp +LIB_OBJS = $(LIB_SRCS:.cpp=.o) +LIB_TARGET = $(PROJECT_ROOT)/libsqlparser.a + +# Google Test library +GTEST_DIR = $(PROJECT_ROOT)/third_party/googletest/googletest +GTEST_SRC = $(GTEST_DIR)/src/gtest-all.cc +GTEST_OBJ = $(GTEST_DIR)/src/gtest-all.o +GTEST_CPPFLAGS = -I$(GTEST_DIR)/include -I$(GTEST_DIR) + +# Test sources +TEST_SRCS = $(TEST_DIR)/test_main.cpp \ + $(TEST_DIR)/test_arena.cpp \ + $(TEST_DIR)/test_tokenizer.cpp \ + $(TEST_DIR)/test_classifier.cpp +TEST_OBJS = $(TEST_SRCS:.cpp=.o) +TEST_TARGET = $(PROJECT_ROOT)/run_tests + +.PHONY: all lib test clean + +all: lib test + +lib: $(LIB_TARGET) + +$(LIB_TARGET): $(LIB_OBJS) + ar rcs $@ $^ + @echo "Built $@" + +$(SRC_DIR)/%.o: $(SRC_DIR)/%.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) -c $< -o $@ + +# Google Test object +$(GTEST_OBJ): $(GTEST_SRC) + $(CXX) $(CXXFLAGS) $(GTEST_CPPFLAGS) -c $< -o $@ + +# Test objects +$(TEST_DIR)/%.o: $(TEST_DIR)/%.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(GTEST_CPPFLAGS) -c $< -o $@ + +test: $(TEST_TARGET) + ./$(TEST_TARGET) + +$(TEST_TARGET): $(TEST_OBJS) $(GTEST_OBJ) $(LIB_TARGET) + $(CXX) $(CXXFLAGS) -o $@ $(TEST_OBJS) $(GTEST_OBJ) -L$(PROJECT_ROOT) 
-lsqlparser -lpthread + +clean: + rm -f $(LIB_OBJS) $(LIB_TARGET) $(TEST_OBJS) $(GTEST_OBJ) $(TEST_TARGET) + @echo "Cleaned." +``` + +- [ ] **Step 4: Create directory structure** + +Run: +```bash +mkdir -p include/sql_parser src/sql_parser +``` + +- [ ] **Step 5: Commit** + +```bash +git add Makefile.new tests/test_main.cpp third_party/ +git commit -m "feat: add new build system and test infrastructure for sql_parser" +``` + +--- + +### Task 2: Core Types — StringRef, Enums + +**Files:** +- Create: `include/sql_parser/common.h` + +- [ ] **Step 1: Write common.h** + +Create `include/sql_parser/common.h`: +```cpp +#ifndef SQL_PARSER_COMMON_H +#define SQL_PARSER_COMMON_H + +#include +#include +#include + +namespace sql_parser { + +// -- Dialect -- + +enum class Dialect : uint8_t { + MySQL, + PostgreSQL +}; + +// -- StringRef: zero-copy view into input buffer -- + +struct StringRef { + const char* ptr = nullptr; + uint32_t len = 0; + + bool empty() const { return len == 0; } + + bool equals_ci(const char* s, uint32_t slen) const { + if (len != slen) return false; + for (uint32_t i = 0; i < len; ++i) { + char a = ptr[i]; + char b = s[i]; + // fast ASCII tolower + if (a >= 'A' && a <= 'Z') a += 32; + if (b >= 'A' && b <= 'Z') b += 32; + if (a != b) return false; + } + return true; + } + + bool operator==(const StringRef& o) const { + return len == o.len && (ptr == o.ptr || std::memcmp(ptr, o.ptr, len) == 0); + } + bool operator!=(const StringRef& o) const { return !(*this == o); } +}; +static_assert(std::is_trivially_copyable_v); + +// Case-insensitive comparison for keyword lookup (used by keyword tables) +inline int ci_cmp(const char* a, uint32_t alen, const char* b, uint8_t blen) { + uint32_t minlen = alen < blen ? alen : blen; + for (uint32_t i = 0; i < minlen; ++i) { + char ca = a[i]; + char cb = b[i]; + if (ca >= 'a' && ca <= 'z') ca -= 32; + if (cb >= 'a' && cb <= 'z') cb -= 32; + if (ca != cb) return ca < cb ? 
-1 : 1; + } + if (alen < blen) return -1; + if (alen > blen) return 1; + return 0; +} + +// -- Statement type (always set, even for PARTIAL/ERROR) -- + +enum class StmtType : uint8_t { + UNKNOWN = 0, + SELECT, + INSERT, + UPDATE, + DELETE_STMT, // avoid clash with delete keyword + REPLACE, + SET, + USE, + SHOW, + BEGIN, + START_TRANSACTION, + COMMIT, + ROLLBACK, + SAVEPOINT, + PREPARE, + EXECUTE, + DEALLOCATE, + CREATE, + ALTER, + DROP, + TRUNCATE, + GRANT, + REVOKE, + LOCK, + UNLOCK, + LOAD_DATA, + RESET, // PostgreSQL RESET +}; + +// -- AST node types -- + +enum class NodeType : uint16_t { + NODE_UNKNOWN = 0, + + // Tier 2 lightweight nodes + NODE_STATEMENT, // root wrapper with StmtType in flags + NODE_TABLE_REF, // table name + NODE_SCHEMA_REF, // schema/database name + NODE_IDENTIFIER, + NODE_QUALIFIED_NAME, // schema.table or table.column + + // Tier 1 nodes (SELECT) — defined here, used in future plan + NODE_SELECT_STMT, + NODE_SELECT_OPTIONS, + NODE_SELECT_ITEM_LIST, + NODE_SELECT_ITEM, + NODE_FROM_CLAUSE, + NODE_JOIN_CLAUSE, + NODE_WHERE_CLAUSE, + NODE_GROUP_BY_CLAUSE, + NODE_HAVING_CLAUSE, + NODE_ORDER_BY_CLAUSE, + NODE_ORDER_BY_ITEM, + NODE_LIMIT_CLAUSE, + NODE_LOCKING_CLAUSE, + NODE_INTO_CLAUSE, + NODE_ALIAS, + + // Tier 1 nodes (SET) — defined here, used in future plan + NODE_SET_STMT, + NODE_SET_NAMES, + NODE_SET_CHARSET, + NODE_SET_TRANSACTION, + NODE_VAR_ASSIGNMENT, + NODE_VAR_TARGET, + + // Expression nodes — defined here, used in future plan + NODE_EXPRESSION, + NODE_BINARY_OP, + NODE_UNARY_OP, + NODE_FUNCTION_CALL, + NODE_LITERAL_INT, + NODE_LITERAL_FLOAT, + NODE_LITERAL_STRING, + NODE_LITERAL_NULL, + NODE_PLACEHOLDER, // ? or $N + NODE_SUBQUERY, + NODE_COLUMN_REF, + NODE_ASTERISK, + NODE_IS_NULL, + NODE_IS_NOT_NULL, + NODE_BETWEEN, + NODE_IN_LIST, + NODE_CASE_WHEN, +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_COMMON_H +``` + +- [ ] **Step 2: Write a compile test** + +This is a header-only file. 
Verify it compiles: + +Run: +```bash +echo '#include "sql_parser/common.h"' > /tmp/test_common.cpp && \ +echo 'int main() { sql_parser::StringRef s; return s.empty() ? 0 : 1; }' >> /tmp/test_common.cpp && \ +g++ -std=c++17 -I./include -c /tmp/test_common.cpp -o /dev/null && echo "OK" +``` +Expected: `OK` + +- [ ] **Step 3: Commit** + +```bash +git add include/sql_parser/common.h +git commit -m "feat: add core types — StringRef, Dialect, StmtType, NodeType enums" +``` + +--- + +### Task 3: Arena Allocator + +**Files:** +- Create: `include/sql_parser/arena.h` +- Create: `src/sql_parser/arena.cpp` +- Create: `tests/test_arena.cpp` + +- [ ] **Step 1: Write the failing test** + +Create `tests/test_arena.cpp`: +```cpp +#include +#include "sql_parser/arena.h" + +using namespace sql_parser; + +TEST(ArenaTest, AllocateAndReset) { + Arena arena(4096); // 4KB block + void* p1 = arena.allocate(64); + ASSERT_NE(p1, nullptr); + void* p2 = arena.allocate(64); + ASSERT_NE(p2, nullptr); + EXPECT_NE(p1, p2); + + arena.reset(); + // After reset, next allocation reuses the same block + void* p3 = arena.allocate(64); + ASSERT_NE(p3, nullptr); + EXPECT_EQ(p1, p3); // same address — block was reused +} + +TEST(ArenaTest, AllocateAligned) { + Arena arena(4096); + void* p1 = arena.allocate(1); // 1 byte + void* p2 = arena.allocate(8); // should be 8-byte aligned + EXPECT_EQ(reinterpret_cast(p2) % 8, 0u); +} + +TEST(ArenaTest, OverflowToNewBlock) { + Arena arena(128); // small block + // Allocate more than one block's worth + void* p1 = arena.allocate(100); + ASSERT_NE(p1, nullptr); + void* p2 = arena.allocate(100); // should overflow to second block + ASSERT_NE(p2, nullptr); + EXPECT_NE(p1, p2); +} + +TEST(ArenaTest, ResetFreesOverflowBlocks) { + Arena arena(128); + arena.allocate(100); + arena.allocate(100); // overflow block allocated + arena.reset(); + // First allocation after reset should be in the primary block + void* p = arena.allocate(64); + ASSERT_NE(p, nullptr); +} + 
+TEST(ArenaTest, MaxSizeEnforced) { + Arena arena(128, 256); // 128 block size, 256 max total + void* p1 = arena.allocate(100); + ASSERT_NE(p1, nullptr); + void* p2 = arena.allocate(100); + ASSERT_NE(p2, nullptr); + // Third allocation exceeds 256 max + void* p3 = arena.allocate(100); + EXPECT_EQ(p3, nullptr); +} + +TEST(ArenaTest, AllocateTyped) { + Arena arena(4096); + + struct TestStruct { + int a; + double b; + }; + + TestStruct* ts = arena.allocate_typed(); + ASSERT_NE(ts, nullptr); + ts->a = 42; + ts->b = 3.14; + EXPECT_EQ(ts->a, 42); + EXPECT_DOUBLE_EQ(ts->b, 3.14); +} + +TEST(ArenaTest, AllocateString) { + Arena arena(4096); + const char* src = "hello world"; + StringRef ref = arena.allocate_string(src, 11); + EXPECT_EQ(ref.len, 11u); + EXPECT_EQ(std::memcmp(ref.ptr, "hello world", 11), 0); + // Should be a copy, not the same pointer + EXPECT_NE(ref.ptr, src); +} +``` + +- [ ] **Step 2: Write arena.h** + +Create `include/sql_parser/arena.h`: +```cpp +#ifndef SQL_PARSER_ARENA_H +#define SQL_PARSER_ARENA_H + +#include "sql_parser/common.h" +#include +#include +#include +#include + +namespace sql_parser { + +class Arena { +public: + explicit Arena(size_t block_size = 65536, size_t max_size = 1048576); + ~Arena(); + + // Non-copyable, non-movable + Arena(const Arena&) = delete; + Arena& operator=(const Arena&) = delete; + Arena(Arena&&) = delete; + Arena& operator=(Arena&&) = delete; + + // Allocate raw bytes (8-byte aligned). Returns nullptr if max_size exceeded. + void* allocate(size_t bytes); + + // Allocate and default-construct a typed object. + template + T* allocate_typed() { + void* mem = allocate(sizeof(T)); + if (!mem) return nullptr; + return new (mem) T{}; + } + + // Copy a string into the arena and return a StringRef to the copy. + StringRef allocate_string(const char* src, uint32_t len); + + // Reset: rewind primary block, free overflow blocks. O(1) in common case. + void reset(); + + // Current total bytes allocated (across all blocks). 
+ size_t bytes_used() const; + +private: + struct Block { + Block* next; + size_t capacity; + size_t used; + // Data follows immediately after this header. + char* data() { return reinterpret_cast(this) + sizeof(Block); } + }; + + Block* allocate_block(size_t capacity); + + Block* primary_; + Block* current_; + size_t block_size_; + size_t max_size_; + size_t total_allocated_; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_ARENA_H +``` + +- [ ] **Step 3: Write arena.cpp** + +Create `src/sql_parser/arena.cpp`: +```cpp +#include "sql_parser/arena.h" +#include + +namespace sql_parser { + +Arena::Block* Arena::allocate_block(size_t capacity) { + void* mem = std::malloc(sizeof(Block) + capacity); + if (!mem) return nullptr; + Block* block = static_cast(mem); + block->next = nullptr; + block->capacity = capacity; + block->used = 0; + return block; +} + +Arena::Arena(size_t block_size, size_t max_size) + : block_size_(block_size), max_size_(max_size), total_allocated_(0) { + primary_ = allocate_block(block_size_); + current_ = primary_; + total_allocated_ = block_size_; +} + +Arena::~Arena() { + Block* b = primary_; + while (b) { + Block* next = b->next; + std::free(b); + b = next; + } +} + +void* Arena::allocate(size_t bytes) { + // Align to 8 bytes + bytes = (bytes + 7) & ~size_t(7); + + // Try current block + if (current_->used + bytes <= current_->capacity) { + void* ptr = current_->data() + current_->used; + current_->used += bytes; + return ptr; + } + + // Need a new block + size_t new_cap = (bytes > block_size_) ? 
bytes : block_size_; + if (total_allocated_ + new_cap > max_size_) { + return nullptr; // max size exceeded + } + + Block* new_block = allocate_block(new_cap); + if (!new_block) return nullptr; + + current_->next = new_block; + current_ = new_block; + total_allocated_ += new_cap; + + void* ptr = current_->data() + current_->used; + current_->used += bytes; + return ptr; +} + +StringRef Arena::allocate_string(const char* src, uint32_t len) { + void* mem = allocate(len); + if (!mem) return StringRef{nullptr, 0}; + std::memcpy(mem, src, len); + return StringRef{static_cast(mem), len}; +} + +void Arena::reset() { + // Free overflow blocks + Block* b = primary_->next; + while (b) { + Block* next = b->next; + std::free(b); + b = next; + } + primary_->next = nullptr; + primary_->used = 0; + current_ = primary_; + total_allocated_ = block_size_; +} + +size_t Arena::bytes_used() const { + size_t used = 0; + const Block* b = primary_; + while (b) { + used += b->used; + b = b->next; + } + return used; +} + +} // namespace sql_parser +``` + +- [ ] **Step 4: Build and run tests** + +Run: +```bash +make -f Makefile.new test +``` +Expected: All ArenaTest tests PASS. 
+
+- [ ] **Step 5: Commit**
+
+```bash
+git add include/sql_parser/arena.h src/sql_parser/arena.cpp tests/test_arena.cpp
+git commit -m "feat: add arena allocator with block chaining and max size"
+```
+
+---
+
+### Task 4: AstNode and ParseResult
+
+**Files:**
+- Create: `include/sql_parser/ast.h`
+- Create: `include/sql_parser/parse_result.h`
+
+- [ ] **Step 1: Write ast.h**
+
+Create `include/sql_parser/ast.h`:
+```cpp
+#ifndef SQL_PARSER_AST_H
+#define SQL_PARSER_AST_H
+
+#include "sql_parser/common.h"
+#include "sql_parser/arena.h"
+#include <cstdint>
+#include <type_traits>
+
+namespace sql_parser {
+
+struct AstNode {
+  AstNode* first_child;
+  AstNode* next_sibling;
+  const char* value_ptr;
+  uint32_t value_len;
+  NodeType type;
+  uint16_t flags;
+
+  // Convenience accessors
+  StringRef value() const { return StringRef{value_ptr, value_len}; }
+
+  void set_value(StringRef ref) {
+    value_ptr = ref.ptr;
+    value_len = ref.len;
+  }
+
+  // Append child to end of child list
+  void add_child(AstNode* child) {
+    if (!child) return;
+    if (!first_child) {
+      first_child = child;
+      return;
+    }
+    AstNode* last = first_child;
+    while (last->next_sibling) last = last->next_sibling;
+    last->next_sibling = child;
+  }
+};
+static_assert(sizeof(AstNode) == 32, "AstNode must be 32 bytes");
+static_assert(std::is_trivially_copyable_v<AstNode>);
+
+// Factory: allocate an AstNode from the arena
+inline AstNode* make_node(Arena& arena, NodeType type, StringRef value = {},
+                          uint16_t flags = 0) {
+  AstNode* node = arena.allocate_typed<AstNode>();
+  if (!node) return nullptr;
+  node->type = type;
+  node->flags = flags;
+  node->value_ptr = value.ptr;
+  node->value_len = value.len;
+  return node;
+}
+
+} // namespace sql_parser
+
+#endif // SQL_PARSER_AST_H
+```
+
+- [ ] **Step 2: Write parse_result.h**
+
+Create `include/sql_parser/parse_result.h`:
+```cpp
+#ifndef SQL_PARSER_PARSE_RESULT_H
+#define SQL_PARSER_PARSE_RESULT_H
+
+#include "sql_parser/common.h"
+#include "sql_parser/ast.h"
+
+namespace sql_parser {
+
+struct 
ErrorInfo { + uint32_t offset = 0; + StringRef message; +}; + +struct ParseResult { + enum Status : uint8_t { OK = 0, PARTIAL, ERROR }; + + Status status = ERROR; + StmtType stmt_type = StmtType::UNKNOWN; + AstNode* ast = nullptr; + ErrorInfo error; + StringRef remaining; // unparsed input after semicolon + + // Tier 2 extracted metadata + StringRef table_name; + StringRef schema_name; + StringRef database_name; // for USE statements + + bool ok() const { return status == OK; } + bool has_remaining() const { return !remaining.empty(); } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_PARSE_RESULT_H +``` + +- [ ] **Step 3: Compile test** + +Run: +```bash +echo '#include "sql_parser/parse_result.h"' > /tmp/test_pr.cpp && \ +echo 'int main() { sql_parser::ParseResult r; return r.ok() ? 0 : 1; }' >> /tmp/test_pr.cpp && \ +g++ -std=c++17 -I./include -c /tmp/test_pr.cpp -o /dev/null && echo "OK" +``` +Expected: `OK` + +- [ ] **Step 4: Commit** + +```bash +git add include/sql_parser/ast.h include/sql_parser/parse_result.h +git commit -m "feat: add AstNode (32-byte) and ParseResult structs" +``` + +--- + +### Task 5: Token Types and Keyword Tables + +**Files:** +- Create: `include/sql_parser/token.h` +- Create: `include/sql_parser/keywords_mysql.h` +- Create: `include/sql_parser/keywords_pgsql.h` + +- [ ] **Step 1: Write token.h** + +Create `include/sql_parser/token.h`: +```cpp +#ifndef SQL_PARSER_TOKEN_H +#define SQL_PARSER_TOKEN_H + +#include "sql_parser/common.h" +#include + +namespace sql_parser { + +enum class TokenType : uint16_t { + // End / error + TK_EOF = 0, + TK_ERROR, + + // Literals + TK_IDENTIFIER, + TK_INTEGER, + TK_FLOAT, + TK_STRING, + + // Punctuation + TK_LPAREN, // ( + TK_RPAREN, // ) + TK_COMMA, // , + TK_SEMICOLON, // ; + TK_DOT, // . 
+ TK_ASTERISK, // * + TK_PLUS, // + + TK_MINUS, // - + TK_SLASH, // / + TK_PERCENT, // % + TK_EQUAL, // = + TK_NOT_EQUAL, // != or <> + TK_LESS, // < + TK_GREATER, // > + TK_LESS_EQUAL, // <= + TK_GREATER_EQUAL, // >= + TK_AMPERSAND, // & + TK_PIPE, // | + TK_CARET, // ^ + TK_TILDE, // ~ + TK_EXCLAIM, // ! + TK_COLON, // : + TK_QUESTION, // ? + TK_AT, // @ + TK_DOUBLE_AT, // @@ + TK_HASH, // # + + // MySQL-specific operators + TK_COLON_EQUAL, // := + TK_DOUBLE_PIPE, // || (also PgSQL string concat) + + // PostgreSQL-specific operators + TK_DOUBLE_COLON, // :: + TK_DOLLAR_NUM, // $1, $2 etc. (prepared stmt placeholder) + + // Keywords — DML + TK_SELECT, + TK_INSERT, + TK_UPDATE, + TK_DELETE, + TK_REPLACE, + TK_FROM, + TK_WHERE, + TK_SET, + TK_INTO, + TK_VALUES, + TK_AS, + TK_ON, + TK_USING, + + // Keywords — clauses + TK_JOIN, + TK_INNER, + TK_LEFT, + TK_RIGHT, + TK_FULL, + TK_OUTER, + TK_CROSS, + TK_NATURAL, + TK_ORDER, + TK_BY, + TK_GROUP, + TK_HAVING, + TK_LIMIT, + TK_OFFSET, + TK_FETCH, + TK_ASC, + TK_DESC, + TK_DISTINCT, + TK_ALL, + + // Keywords — logical / comparison + TK_AND, + TK_OR, + TK_NOT, + TK_IS, + TK_NULL, + TK_IN, + TK_BETWEEN, + TK_LIKE, + TK_EXISTS, + TK_CASE, + TK_WHEN, + TK_THEN, + TK_ELSE, + TK_END, + TK_TRUE, + TK_FALSE, + + // Keywords — SET + TK_NAMES, + TK_CHARACTER, + TK_CHARSET, + TK_COLLATE, + TK_GLOBAL, + TK_SESSION, + TK_LOCAL, + TK_PERSIST, + TK_DEFAULT, + TK_TRANSACTION, + TK_ISOLATION, + TK_LEVEL, + TK_READ, + TK_WRITE, + TK_ONLY, + TK_COMMITTED, + TK_UNCOMMITTED, + TK_REPEATABLE, + TK_SERIALIZABLE, + TK_TO, + + // Keywords — DDL + TK_CREATE, + TK_ALTER, + TK_DROP, + TK_TRUNCATE, + TK_TABLE, + TK_INDEX, + TK_VIEW, + TK_DATABASE, + TK_SCHEMA, + TK_IF, + + // Keywords — transaction + TK_BEGIN, + TK_START, + TK_COMMIT, + TK_ROLLBACK, + TK_SAVEPOINT, + + // Keywords — other + TK_USE, + TK_SHOW, + TK_PREPARE, + TK_EXECUTE, + TK_DEALLOCATE, + TK_GRANT, + TK_REVOKE, + TK_LOCK, + TK_UNLOCK, + TK_LOAD, + TK_DATA, + TK_FOR, + TK_SHARE, + 
TK_NOWAIT, + TK_SKIP, + TK_LOCKED, + TK_OUTFILE, + TK_DUMPFILE, + TK_IGNORE, + TK_LOW_PRIORITY, + TK_QUICK, + TK_RESET, + + // MySQL-specific keywords + TK_SQL_CALC_FOUND_ROWS, + + // Aggregate / functions (recognized as keywords for fast dispatch) + TK_COUNT, + TK_SUM, + TK_AVG, + TK_MIN, + TK_MAX, +}; + +struct Token { + TokenType type = TokenType::TK_EOF; + StringRef text; + uint32_t offset = 0; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_TOKEN_H +``` + +- [ ] **Step 2: Write keywords_mysql.h** + +Create `include/sql_parser/keywords_mysql.h`: + +This uses a sorted array + binary search for keyword lookup. A perfect hash can be added later as an optimization. + +```cpp +#ifndef SQL_PARSER_KEYWORDS_MYSQL_H +#define SQL_PARSER_KEYWORDS_MYSQL_H + +#include "sql_parser/token.h" +#include +#include + +namespace sql_parser { +namespace mysql_keywords { + +struct KeywordEntry { + const char* text; + uint8_t len; + TokenType token; +}; + +// Sorted by text (case-insensitive) for binary search. +// Must stay sorted when adding entries. 
+inline constexpr KeywordEntry KEYWORDS[] = { + {"ALL", 3, TokenType::TK_ALL}, + {"ALTER", 5, TokenType::TK_ALTER}, + {"AND", 3, TokenType::TK_AND}, + {"AS", 2, TokenType::TK_AS}, + {"ASC", 3, TokenType::TK_ASC}, + {"AVG", 3, TokenType::TK_AVG}, + {"BEGIN", 5, TokenType::TK_BEGIN}, + {"BETWEEN", 7, TokenType::TK_BETWEEN}, + {"BY", 2, TokenType::TK_BY}, + {"CASE", 4, TokenType::TK_CASE}, + {"CHARACTER", 9, TokenType::TK_CHARACTER}, + {"CHARSET", 7, TokenType::TK_CHARSET}, + {"COLLATE", 7, TokenType::TK_COLLATE}, + {"COMMIT", 6, TokenType::TK_COMMIT}, + {"COMMITTED", 9, TokenType::TK_COMMITTED}, + {"COUNT", 5, TokenType::TK_COUNT}, + {"CREATE", 6, TokenType::TK_CREATE}, + {"CROSS", 5, TokenType::TK_CROSS}, + {"DATA", 4, TokenType::TK_DATA}, + {"DATABASE", 8, TokenType::TK_DATABASE}, + {"DEALLOCATE", 10, TokenType::TK_DEALLOCATE}, + {"DEFAULT", 7, TokenType::TK_DEFAULT}, + {"DELETE", 6, TokenType::TK_DELETE}, + {"DESC", 4, TokenType::TK_DESC}, + {"DISTINCT", 8, TokenType::TK_DISTINCT}, + {"DROP", 4, TokenType::TK_DROP}, + {"DUMPFILE", 8, TokenType::TK_DUMPFILE}, + {"ELSE", 4, TokenType::TK_ELSE}, + {"END", 3, TokenType::TK_END}, + {"EXECUTE", 7, TokenType::TK_EXECUTE}, + {"EXISTS", 6, TokenType::TK_EXISTS}, + {"FALSE", 5, TokenType::TK_FALSE}, + {"FETCH", 5, TokenType::TK_FETCH}, + {"FOR", 3, TokenType::TK_FOR}, + {"FROM", 4, TokenType::TK_FROM}, + {"FULL", 4, TokenType::TK_FULL}, + {"GLOBAL", 6, TokenType::TK_GLOBAL}, + {"GRANT", 5, TokenType::TK_GRANT}, + {"GROUP", 5, TokenType::TK_GROUP}, + {"HAVING", 6, TokenType::TK_HAVING}, + {"IF", 2, TokenType::TK_IF}, + {"IGNORE", 6, TokenType::TK_IGNORE}, + {"IN", 2, TokenType::TK_IN}, + {"INDEX", 5, TokenType::TK_INDEX}, + {"INNER", 5, TokenType::TK_INNER}, + {"INSERT", 6, TokenType::TK_INSERT}, + {"INTO", 4, TokenType::TK_INTO}, + {"IS", 2, TokenType::TK_IS}, + {"ISOLATION", 9, TokenType::TK_ISOLATION}, + {"JOIN", 4, TokenType::TK_JOIN}, + {"LEFT", 4, TokenType::TK_LEFT}, + {"LEVEL", 5, TokenType::TK_LEVEL}, + {"LIKE", 4, 
TokenType::TK_LIKE}, + {"LIMIT", 5, TokenType::TK_LIMIT}, + {"LOAD", 4, TokenType::TK_LOAD}, + {"LOCAL", 5, TokenType::TK_LOCAL}, + {"LOCK", 4, TokenType::TK_LOCK}, + {"LOCKED", 6, TokenType::TK_LOCKED}, + {"LOW_PRIORITY", 12, TokenType::TK_LOW_PRIORITY}, + {"MAX", 3, TokenType::TK_MAX}, + {"MIN", 3, TokenType::TK_MIN}, + {"NAMES", 5, TokenType::TK_NAMES}, + {"NATURAL", 7, TokenType::TK_NATURAL}, + {"NOT", 3, TokenType::TK_NOT}, + {"NOWAIT", 6, TokenType::TK_NOWAIT}, + {"NULL", 4, TokenType::TK_NULL}, + {"OFFSET", 6, TokenType::TK_OFFSET}, + {"ON", 2, TokenType::TK_ON}, + {"ONLY", 4, TokenType::TK_ONLY}, + {"OR", 2, TokenType::TK_OR}, + {"ORDER", 5, TokenType::TK_ORDER}, + {"OUTER", 5, TokenType::TK_OUTER}, + {"OUTFILE", 7, TokenType::TK_OUTFILE}, + {"PERSIST", 7, TokenType::TK_PERSIST}, + {"PREPARE", 7, TokenType::TK_PREPARE}, + {"QUICK", 5, TokenType::TK_QUICK}, + {"READ", 4, TokenType::TK_READ}, + {"REPEATABLE", 10, TokenType::TK_REPEATABLE}, + {"REPLACE", 7, TokenType::TK_REPLACE}, + {"RESET", 5, TokenType::TK_RESET}, + {"REVOKE", 6, TokenType::TK_REVOKE}, + {"RIGHT", 5, TokenType::TK_RIGHT}, + {"ROLLBACK", 8, TokenType::TK_ROLLBACK}, + {"SAVEPOINT", 9, TokenType::TK_SAVEPOINT}, + {"SCHEMA", 6, TokenType::TK_SCHEMA}, + {"SELECT", 6, TokenType::TK_SELECT}, + {"SERIALIZABLE", 12, TokenType::TK_SERIALIZABLE}, + {"SESSION", 7, TokenType::TK_SESSION}, + {"SET", 3, TokenType::TK_SET}, + {"SHARE", 5, TokenType::TK_SHARE}, + {"SHOW", 4, TokenType::TK_SHOW}, + {"SKIP", 4, TokenType::TK_SKIP}, + {"SQL_CALC_FOUND_ROWS", 19, TokenType::TK_SQL_CALC_FOUND_ROWS}, + {"START", 5, TokenType::TK_START}, + {"SUM", 3, TokenType::TK_SUM}, + {"TABLE", 5, TokenType::TK_TABLE}, + {"THEN", 4, TokenType::TK_THEN}, + {"TO", 2, TokenType::TK_TO}, + {"TRANSACTION", 11, TokenType::TK_TRANSACTION}, + {"TRUE", 4, TokenType::TK_TRUE}, + {"TRUNCATE", 8, TokenType::TK_TRUNCATE}, + {"UNCOMMITTED", 11, TokenType::TK_UNCOMMITTED}, + {"UNLOCK", 6, TokenType::TK_UNLOCK}, + {"UPDATE", 6, 
TokenType::TK_UPDATE}, + {"USE", 3, TokenType::TK_USE}, + {"USING", 5, TokenType::TK_USING}, + {"VALUES", 6, TokenType::TK_VALUES}, + {"VIEW", 4, TokenType::TK_VIEW}, + {"WHEN", 4, TokenType::TK_WHEN}, + {"WHERE", 5, TokenType::TK_WHERE}, + {"WRITE", 5, TokenType::TK_WRITE}, +}; + +inline constexpr size_t KEYWORD_COUNT = sizeof(KEYWORDS) / sizeof(KEYWORDS[0]); + +// Returns TK_IDENTIFIER if not a keyword. +inline TokenType lookup(const char* text, uint32_t len) { + size_t lo = 0, hi = KEYWORD_COUNT; + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + int cmp = sql_parser::ci_cmp(text, len, KEYWORDS[mid].text, KEYWORDS[mid].len); + if (cmp == 0) return KEYWORDS[mid].token; + if (cmp < 0) hi = mid; + else lo = mid + 1; + } + return TokenType::TK_IDENTIFIER; +} + +} // namespace mysql_keywords +} // namespace sql_parser + +#endif // SQL_PARSER_KEYWORDS_MYSQL_H +``` + +- [ ] **Step 3: Write keywords_pgsql.h** + +Create `include/sql_parser/keywords_pgsql.h`: + +Same structure, with PostgreSQL-specific keyword set. Uses the same lookup pattern. + +```cpp +#ifndef SQL_PARSER_KEYWORDS_PGSQL_H +#define SQL_PARSER_KEYWORDS_PGSQL_H + +#include "sql_parser/token.h" + +namespace sql_parser { +namespace pgsql_keywords { + +struct KeywordEntry { + const char* text; + uint8_t len; + TokenType token; +}; + +// Sorted by text (case-insensitive) for binary search. 
+// PostgreSQL shares most keywords with MySQL; main differences: +// - No SQL_CALC_FOUND_ROWS +// - RESET is a first-class keyword (not just for SET) +// - TO is used in SET x TO y +inline constexpr KeywordEntry KEYWORDS[] = { + {"ALL", 3, TokenType::TK_ALL}, + {"ALTER", 5, TokenType::TK_ALTER}, + {"AND", 3, TokenType::TK_AND}, + {"AS", 2, TokenType::TK_AS}, + {"ASC", 3, TokenType::TK_ASC}, + {"AVG", 3, TokenType::TK_AVG}, + {"BEGIN", 5, TokenType::TK_BEGIN}, + {"BETWEEN", 7, TokenType::TK_BETWEEN}, + {"BY", 2, TokenType::TK_BY}, + {"CASE", 4, TokenType::TK_CASE}, + {"CHARACTER", 9, TokenType::TK_CHARACTER}, + {"COLLATE", 7, TokenType::TK_COLLATE}, + {"COMMIT", 6, TokenType::TK_COMMIT}, + {"COMMITTED", 9, TokenType::TK_COMMITTED}, + {"COUNT", 5, TokenType::TK_COUNT}, + {"CREATE", 6, TokenType::TK_CREATE}, + {"CROSS", 5, TokenType::TK_CROSS}, + {"DATA", 4, TokenType::TK_DATA}, + {"DATABASE", 8, TokenType::TK_DATABASE}, + {"DEALLOCATE", 10, TokenType::TK_DEALLOCATE}, + {"DEFAULT", 7, TokenType::TK_DEFAULT}, + {"DELETE", 6, TokenType::TK_DELETE}, + {"DESC", 4, TokenType::TK_DESC}, + {"DISTINCT", 8, TokenType::TK_DISTINCT}, + {"DROP", 4, TokenType::TK_DROP}, + {"ELSE", 4, TokenType::TK_ELSE}, + {"END", 3, TokenType::TK_END}, + {"EXECUTE", 7, TokenType::TK_EXECUTE}, + {"EXISTS", 6, TokenType::TK_EXISTS}, + {"FALSE", 5, TokenType::TK_FALSE}, + {"FETCH", 5, TokenType::TK_FETCH}, + {"FOR", 3, TokenType::TK_FOR}, + {"FROM", 4, TokenType::TK_FROM}, + {"FULL", 4, TokenType::TK_FULL}, + {"GRANT", 5, TokenType::TK_GRANT}, + {"GROUP", 5, TokenType::TK_GROUP}, + {"HAVING", 6, TokenType::TK_HAVING}, + {"IF", 2, TokenType::TK_IF}, + {"IN", 2, TokenType::TK_IN}, + {"INDEX", 5, TokenType::TK_INDEX}, + {"INNER", 5, TokenType::TK_INNER}, + {"INSERT", 6, TokenType::TK_INSERT}, + {"INTO", 4, TokenType::TK_INTO}, + {"IS", 2, TokenType::TK_IS}, + {"ISOLATION", 9, TokenType::TK_ISOLATION}, + {"JOIN", 4, TokenType::TK_JOIN}, + {"LEFT", 4, TokenType::TK_LEFT}, + {"LEVEL", 5, 
TokenType::TK_LEVEL}, + {"LIKE", 4, TokenType::TK_LIKE}, + {"LIMIT", 5, TokenType::TK_LIMIT}, + {"LOAD", 4, TokenType::TK_LOAD}, + {"LOCAL", 5, TokenType::TK_LOCAL}, + {"LOCK", 4, TokenType::TK_LOCK}, + {"MAX", 3, TokenType::TK_MAX}, + {"MIN", 3, TokenType::TK_MIN}, + {"NAMES", 5, TokenType::TK_NAMES}, + {"NATURAL", 7, TokenType::TK_NATURAL}, + {"NOT", 3, TokenType::TK_NOT}, + {"NULL", 4, TokenType::TK_NULL}, + {"OFFSET", 6, TokenType::TK_OFFSET}, + {"ON", 2, TokenType::TK_ON}, + {"ONLY", 4, TokenType::TK_ONLY}, + {"OR", 2, TokenType::TK_OR}, + {"ORDER", 5, TokenType::TK_ORDER}, + {"OUTER", 5, TokenType::TK_OUTER}, + {"PREPARE", 7, TokenType::TK_PREPARE}, + {"READ", 4, TokenType::TK_READ}, + {"REPEATABLE", 10, TokenType::TK_REPEATABLE}, + {"RESET", 5, TokenType::TK_RESET}, + {"REVOKE", 6, TokenType::TK_REVOKE}, + {"RIGHT", 5, TokenType::TK_RIGHT}, + {"ROLLBACK", 8, TokenType::TK_ROLLBACK}, + {"SAVEPOINT", 9, TokenType::TK_SAVEPOINT}, + {"SCHEMA", 6, TokenType::TK_SCHEMA}, + {"SELECT", 6, TokenType::TK_SELECT}, + {"SERIALIZABLE", 12, TokenType::TK_SERIALIZABLE}, + {"SESSION", 7, TokenType::TK_SESSION}, + {"SET", 3, TokenType::TK_SET}, + {"SHARE", 5, TokenType::TK_SHARE}, + {"SHOW", 4, TokenType::TK_SHOW}, + {"START", 5, TokenType::TK_START}, + {"SUM", 3, TokenType::TK_SUM}, + {"TABLE", 5, TokenType::TK_TABLE}, + {"THEN", 4, TokenType::TK_THEN}, + {"TO", 2, TokenType::TK_TO}, + {"TRANSACTION", 11, TokenType::TK_TRANSACTION}, + {"TRUE", 4, TokenType::TK_TRUE}, + {"TRUNCATE", 8, TokenType::TK_TRUNCATE}, + {"UNCOMMITTED", 11, TokenType::TK_UNCOMMITTED}, + {"UNLOCK", 6, TokenType::TK_UNLOCK}, + {"UPDATE", 6, TokenType::TK_UPDATE}, + {"USE", 3, TokenType::TK_USE}, + {"USING", 5, TokenType::TK_USING}, + {"VALUES", 6, TokenType::TK_VALUES}, + {"VIEW", 4, TokenType::TK_VIEW}, + {"WHEN", 4, TokenType::TK_WHEN}, + {"WHERE", 5, TokenType::TK_WHERE}, + {"WRITE", 5, TokenType::TK_WRITE}, +}; + +inline constexpr size_t KEYWORD_COUNT = sizeof(KEYWORDS) / sizeof(KEYWORDS[0]); + +// 
Uses ci_cmp from common.h +inline TokenType lookup(const char* text, uint32_t len) { + size_t lo = 0, hi = KEYWORD_COUNT; + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + int cmp = sql_parser::ci_cmp(text, len, KEYWORDS[mid].text, KEYWORDS[mid].len); + if (cmp == 0) return KEYWORDS[mid].token; + if (cmp < 0) hi = mid; + else lo = mid + 1; + } + return TokenType::TK_IDENTIFIER; +} + +} // namespace pgsql_keywords +} // namespace sql_parser + +#endif // SQL_PARSER_KEYWORDS_PGSQL_H +``` + +- [ ] **Step 4: Compile test** + +Run: +```bash +echo '#include "sql_parser/keywords_mysql.h"' > /tmp/test_kw.cpp && \ +echo '#include "sql_parser/keywords_pgsql.h"' >> /tmp/test_kw.cpp && \ +echo '#include ' >> /tmp/test_kw.cpp && \ +echo 'int main() {' >> /tmp/test_kw.cpp && \ +echo ' assert(sql_parser::mysql_keywords::lookup("SELECT", 6) == sql_parser::TokenType::TK_SELECT);' >> /tmp/test_kw.cpp && \ +echo ' assert(sql_parser::mysql_keywords::lookup("select", 6) == sql_parser::TokenType::TK_SELECT);' >> /tmp/test_kw.cpp && \ +echo ' assert(sql_parser::mysql_keywords::lookup("foobar", 6) == sql_parser::TokenType::TK_IDENTIFIER);' >> /tmp/test_kw.cpp && \ +echo ' return 0; }' >> /tmp/test_kw.cpp && \ +g++ -std=c++17 -I./include /tmp/test_kw.cpp -o /dev/null && echo "OK" +``` +Expected: `OK` + +- [ ] **Step 5: Commit** + +```bash +git add include/sql_parser/token.h include/sql_parser/keywords_mysql.h include/sql_parser/keywords_pgsql.h +git commit -m "feat: add token types and keyword lookup tables for MySQL and PostgreSQL" +``` + +--- + +### Task 6: Tokenizer + +**Files:** +- Create: `include/sql_parser/tokenizer.h` +- Create: `tests/test_tokenizer.cpp` + +- [ ] **Step 1: Write the failing test** + +Create `tests/test_tokenizer.cpp`: +```cpp +#include +#include "sql_parser/tokenizer.h" + +using namespace sql_parser; + +// ========== MySQL Tokenizer Tests ========== + +class MySQLTokenizerTest : public ::testing::Test { +protected: + Tokenizer tok; +}; + 
+TEST_F(MySQLTokenizerTest, SimpleSelect) { + const char* sql = "SELECT * FROM users;"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_SELECT); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_ASTERISK); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_FROM); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + EXPECT_EQ(t.text.len, 5u); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "users"); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_SEMICOLON); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_EOF); +} + +TEST_F(MySQLTokenizerTest, CaseInsensitiveKeywords) { + const char* sql = "select FROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, BacktickIdentifier) { + const char* sql = "`my table`"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "my table"); +} + +TEST_F(MySQLTokenizerTest, SingleQuotedString) { + const char* sql = "'hello world'"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_STRING); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "hello world"); +} + +TEST_F(MySQLTokenizerTest, IntegerLiteral) { + const char* sql = "42"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_INTEGER); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "42"); +} + +TEST_F(MySQLTokenizerTest, FloatLiteral) { + const char* sql = "3.14"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_FLOAT); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "3.14"); +} + +TEST_F(MySQLTokenizerTest, ComparisonOperators) { + const char* sql = "= != < > <= >="; + tok.reset(sql, 
strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_EQUAL); + EXPECT_EQ(tok.next_token().type, TokenType::TK_NOT_EQUAL); + EXPECT_EQ(tok.next_token().type, TokenType::TK_LESS); + EXPECT_EQ(tok.next_token().type, TokenType::TK_GREATER); + EXPECT_EQ(tok.next_token().type, TokenType::TK_LESS_EQUAL); + EXPECT_EQ(tok.next_token().type, TokenType::TK_GREATER_EQUAL); +} + +TEST_F(MySQLTokenizerTest, DiamondNotEqual) { + const char* sql = "<>"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_NOT_EQUAL); +} + +TEST_F(MySQLTokenizerTest, AtVariables) { + const char* sql = "@myvar @@global_var"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_AT); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_DOUBLE_AT); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); +} + +TEST_F(MySQLTokenizerTest, Placeholder) { + const char* sql = "?"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_QUESTION); +} + +TEST_F(MySQLTokenizerTest, ColonEqual) { + const char* sql = ":="; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_COLON_EQUAL); +} + +TEST_F(MySQLTokenizerTest, LineComment) { + const char* sql = "SELECT -- this is a comment\nFROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, HashComment) { + const char* sql = "SELECT # comment\nFROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, BlockComment) { + const char* sql = "SELECT /* comment */ FROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + 
EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, PeekDoesNotConsume) { + const char* sql = "SELECT FROM"; + tok.reset(sql, strlen(sql)); + + Token peeked = tok.peek(); + EXPECT_EQ(peeked.type, TokenType::TK_SELECT); + + Token consumed = tok.next_token(); + EXPECT_EQ(consumed.type, TokenType::TK_SELECT); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, EmptyInput) { + tok.reset("", 0); + EXPECT_EQ(tok.next_token().type, TokenType::TK_EOF); +} + +TEST_F(MySQLTokenizerTest, WhitespaceOnly) { + const char* sql = " \t\n\r "; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_EOF); +} + +TEST_F(MySQLTokenizerTest, QualifiedIdentifier) { + const char* sql = "myschema.orders"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_IDENTIFIER); // myschema + EXPECT_EQ(tok.next_token().type, TokenType::TK_DOT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_IDENTIFIER); // orders +} + +// ========== PostgreSQL Tokenizer Tests ========== + +class PgSQLTokenizerTest : public ::testing::Test { +protected: + Tokenizer tok; +}; + +TEST_F(PgSQLTokenizerTest, DoubleQuotedIdentifier) { + const char* sql = "\"my table\""; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "my table"); +} + +TEST_F(PgSQLTokenizerTest, DollarQuotedString) { + const char* sql = "$$hello world$$"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_STRING); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "hello world"); +} + +TEST_F(PgSQLTokenizerTest, DoubleColonCast) { + const char* sql = "::"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_DOUBLE_COLON); +} + +TEST_F(PgSQLTokenizerTest, PositionalParam) { + const char* sql = "$1 $23"; + tok.reset(sql, strlen(sql)); + + Token t = 
tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_DOLLAR_NUM); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "$1"); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_DOLLAR_NUM); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "$23"); +} + +TEST_F(PgSQLTokenizerTest, NestedBlockComment) { + const char* sql = "SELECT /* outer /* inner */ still comment */ FROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(PgSQLTokenizerTest, NoHashComment) { + // PostgreSQL does NOT support # comments — # should be TK_HASH token + const char* sql = "#"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_HASH); +} +``` + +- [ ] **Step 2: Write tokenizer.h** + +Create `include/sql_parser/tokenizer.h`. This is header-only for max inlining: + +```cpp +#ifndef SQL_PARSER_TOKENIZER_H +#define SQL_PARSER_TOKENIZER_H + +#include "sql_parser/token.h" +#include "sql_parser/keywords_mysql.h" +#include "sql_parser/keywords_pgsql.h" + +namespace sql_parser { + +template +class Tokenizer { +public: + void reset(const char* input, size_t len) { + start_ = input; + cursor_ = input; + end_ = input + len; + has_peeked_ = false; + } + + Token next_token() { + if (has_peeked_) { + has_peeked_ = false; + return peeked_; + } + return scan_token(); + } + + Token peek() { + if (!has_peeked_) { + peeked_ = scan_token(); + has_peeked_ = true; + } + return peeked_; + } + + void skip() { + if (has_peeked_) { + has_peeked_ = false; + } else { + scan_token(); + } + } + + // Expose end of input for remaining-input calculation + const char* input_end() const { return end_; } + +private: + const char* start_ = nullptr; + const char* cursor_ = nullptr; + const char* end_ = nullptr; + Token peeked_; + bool has_peeked_ = false; + + uint32_t offset() const { + return static_cast(cursor_ - start_); + } + + char current() const { return (cursor_ < end_) ? 
*cursor_ : '\0'; } + char advance() { + char c = current(); + if (cursor_ < end_) ++cursor_; + return c; + } + char peek_char(size_t ahead = 0) const { + const char* p = cursor_ + ahead; + return (p < end_) ? *p : '\0'; + } + + void skip_whitespace_and_comments() { + while (cursor_ < end_) { + char c = *cursor_; + + // Whitespace + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + ++cursor_; + continue; + } + + // -- line comment (MySQL requires space after --, PgSQL doesn't but we handle both) + if (c == '-' && peek_char(1) == '-') { + cursor_ += 2; + while (cursor_ < end_ && *cursor_ != '\n') ++cursor_; + continue; + } + + // # line comment (MySQL only) + if constexpr (D == Dialect::MySQL) { + if (c == '#') { + ++cursor_; + while (cursor_ < end_ && *cursor_ != '\n') ++cursor_; + continue; + } + } + + // /* block comment */ + if (c == '/' && peek_char(1) == '*') { + cursor_ += 2; + if constexpr (D == Dialect::PostgreSQL) { + // PostgreSQL supports nested block comments + int depth = 1; + while (cursor_ < end_ && depth > 0) { + if (*cursor_ == '/' && peek_char(1) == '*') { + ++depth; + cursor_ += 2; + } else if (*cursor_ == '*' && peek_char(1) == '/') { + --depth; + cursor_ += 2; + } else { + ++cursor_; + } + } + } else { + // MySQL: no nesting + while (cursor_ < end_) { + if (*cursor_ == '*' && peek_char(1) == '/') { + cursor_ += 2; + break; + } + ++cursor_; + } + } + continue; + } + + break; // not whitespace or comment + } + } + + Token make_token(TokenType type, const char* start, uint32_t len) { + return Token{type, StringRef{start, len}, + static_cast(start - start_)}; + } + + Token scan_identifier_or_keyword() { + const char* start = cursor_; + while (cursor_ < end_) { + char c = *cursor_; + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_') { + ++cursor_; + } else { + break; + } + } + uint32_t len = static_cast(cursor_ - start); + + // Keyword lookup + TokenType kw; + if constexpr (D == Dialect::MySQL) { + 
kw = mysql_keywords::lookup(start, len); + } else { + kw = pgsql_keywords::lookup(start, len); + } + return make_token(kw, start, len); + } + + Token scan_number() { + const char* start = cursor_; + bool has_dot = false; + while (cursor_ < end_) { + char c = *cursor_; + if (c >= '0' && c <= '9') { + ++cursor_; + } else if (c == '.' && !has_dot) { + has_dot = true; + ++cursor_; + } else { + break; + } + } + uint32_t len = static_cast(cursor_ - start); + return make_token(has_dot ? TokenType::TK_FLOAT : TokenType::TK_INTEGER, + start, len); + } + + Token scan_single_quoted_string() { + ++cursor_; // skip opening quote + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '\'') { + if (*cursor_ == '\\') { + ++cursor_; // skip escaped char + if (cursor_ < end_) ++cursor_; + } else { + ++cursor_; + } + } + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; // skip closing quote + return make_token(TokenType::TK_STRING, content_start, len); + } + + // MySQL: backtick-quoted identifier + Token scan_backtick_identifier() { + ++cursor_; // skip opening backtick + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '`') ++cursor_; + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; // skip closing backtick + return make_token(TokenType::TK_IDENTIFIER, content_start, len); + } + + // PostgreSQL: double-quoted identifier + Token scan_double_quoted_identifier() { + ++cursor_; // skip opening quote + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '"') ++cursor_; + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; // skip closing quote + return make_token(TokenType::TK_IDENTIFIER, content_start, len); + } + + // PostgreSQL: $$...$$ dollar-quoted string + Token scan_dollar_string() { + // We're at the first $. 
Simple form: $$content$$ + cursor_ += 2; // skip opening $$ + const char* content_start = cursor_; + while (cursor_ < end_) { + if (*cursor_ == '$' && peek_char(1) == '$') { + uint32_t len = static_cast(cursor_ - content_start); + cursor_ += 2; // skip closing $$ + return make_token(TokenType::TK_STRING, content_start, len); + } + ++cursor_; + } + // Unterminated — return what we have + uint32_t len = static_cast(cursor_ - content_start); + return make_token(TokenType::TK_STRING, content_start, len); + } + + Token scan_token() { + skip_whitespace_and_comments(); + + if (cursor_ >= end_) { + return make_token(TokenType::TK_EOF, cursor_, 0); + } + + char c = *cursor_; + + // Identifiers and keywords + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { + return scan_identifier_or_keyword(); + } + + // Numbers + if (c >= '0' && c <= '9') { + return scan_number(); + } + + // Dot — could be start of .123 float or just dot + if (c == '.' && cursor_ + 1 < end_ && + peek_char(1) >= '0' && peek_char(1) <= '9') { + return scan_number(); + } + + // String literals + if (c == '\'') return scan_single_quoted_string(); + + // MySQL: double-quoted strings; PostgreSQL: double-quoted identifiers + if (c == '"') { + if constexpr (D == Dialect::MySQL) { + // In MySQL, double quotes are strings (unless ANSI_QUOTES mode) + ++cursor_; + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '"') { + if (*cursor_ == '\\') { ++cursor_; if (cursor_ < end_) ++cursor_; } + else ++cursor_; + } + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; + return make_token(TokenType::TK_STRING, content_start, len); + } else { + return scan_double_quoted_identifier(); + } + } + + // Backtick identifier (MySQL only) + if constexpr (D == Dialect::MySQL) { + if (c == '`') return scan_backtick_identifier(); + } + + // @ and @@ + if (c == '@') { + if (peek_char(1) == '@') { + const char* s = cursor_; + cursor_ += 2; + return 
make_token(TokenType::TK_DOUBLE_AT, s, 2); + } + const char* s = cursor_; + ++cursor_; + return make_token(TokenType::TK_AT, s, 1); + } + + // $ — PostgreSQL: $N placeholder or $$string$$ + if constexpr (D == Dialect::PostgreSQL) { + if (c == '$') { + if (peek_char(1) == '$') { + return scan_dollar_string(); + } + if (peek_char(1) >= '0' && peek_char(1) <= '9') { + const char* start = cursor_; + ++cursor_; // skip $ + while (cursor_ < end_ && *cursor_ >= '0' && *cursor_ <= '9') + ++cursor_; + uint32_t len = static_cast(cursor_ - start); + return make_token(TokenType::TK_DOLLAR_NUM, start, len); + } + } + } + + // Two-character operators + if (cursor_ + 1 < end_) { + char c2 = peek_char(1); + + if (c == '<' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_LESS_EQUAL, s, 2); } + if (c == '>' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_GREATER_EQUAL, s, 2); } + if (c == '!' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_NOT_EQUAL, s, 2); } + if (c == '<' && c2 == '>') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_NOT_EQUAL, s, 2); } + if (c == '|' && c2 == '|') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_DOUBLE_PIPE, s, 2); } + + if constexpr (D == Dialect::MySQL) { + if (c == ':' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_COLON_EQUAL, s, 2); } + } + + if constexpr (D == Dialect::PostgreSQL) { + if (c == ':' && c2 == ':') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_DOUBLE_COLON, s, 2); } + } + } + + // Single-character operators/punctuation + const char* s = cursor_; + ++cursor_; + switch (c) { + case '(': return make_token(TokenType::TK_LPAREN, s, 1); + case ')': return make_token(TokenType::TK_RPAREN, s, 1); + case ',': return make_token(TokenType::TK_COMMA, s, 1); + case ';': return make_token(TokenType::TK_SEMICOLON, s, 1); + case '.': return 
make_token(TokenType::TK_DOT, s, 1); + case '*': return make_token(TokenType::TK_ASTERISK, s, 1); + case '+': return make_token(TokenType::TK_PLUS, s, 1); + case '-': return make_token(TokenType::TK_MINUS, s, 1); + case '/': return make_token(TokenType::TK_SLASH, s, 1); + case '%': return make_token(TokenType::TK_PERCENT, s, 1); + case '=': return make_token(TokenType::TK_EQUAL, s, 1); + case '<': return make_token(TokenType::TK_LESS, s, 1); + case '>': return make_token(TokenType::TK_GREATER, s, 1); + case '&': return make_token(TokenType::TK_AMPERSAND, s, 1); + case '|': return make_token(TokenType::TK_PIPE, s, 1); + case '^': return make_token(TokenType::TK_CARET, s, 1); + case '~': return make_token(TokenType::TK_TILDE, s, 1); + case '!': return make_token(TokenType::TK_EXCLAIM, s, 1); + case ':': return make_token(TokenType::TK_COLON, s, 1); + case '?': return make_token(TokenType::TK_QUESTION, s, 1); + case '#': return make_token(TokenType::TK_HASH, s, 1); + default: return make_token(TokenType::TK_ERROR, s, 1); + } + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_TOKENIZER_H +``` + +- [ ] **Step 3: Build and run tests** + +Run: +```bash +make -f Makefile.new test +``` +Expected: All MySQLTokenizerTest and PgSQLTokenizerTest tests PASS. 
+ +- [ ] **Step 4: Commit** + +```bash +git add include/sql_parser/tokenizer.h tests/test_tokenizer.cpp +git commit -m "feat: add dialect-templated tokenizer with MySQL and PostgreSQL support" +``` + +--- + +### Task 7: Parser — Classifier and Tier 2 Extractors + +**Files:** +- Create: `include/sql_parser/parser.h` +- Create: `src/sql_parser/parser.cpp` +- Create: `tests/test_classifier.cpp` + +- [ ] **Step 1: Write the failing test** + +Create `tests/test_classifier.cpp`: +```cpp +#include +#include "sql_parser/parser.h" + +using namespace sql_parser; + +// ========== MySQL Classifier Tests ========== + +class MySQLClassifierTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(MySQLClassifierTest, ClassifySelect) { + auto r = parser.parse("SELECT * FROM users", 19); + // SELECT is Tier 1 — for now returns PARTIAL until deep parser is implemented + EXPECT_EQ(r.stmt_type, StmtType::SELECT); +} + +TEST_F(MySQLClassifierTest, ClassifyInsert) { + auto r = parser.parse("INSERT INTO users VALUES (1, 'a')", 33); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyInsertQualified) { + auto r = parser.parse("INSERT INTO mydb.users VALUES (1)", 33); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.schema_name.ptr, r.schema_name.len), "mydb"); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyUpdate) { + auto r = parser.parse("UPDATE users SET name='x'", 25); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::UPDATE); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyDelete) { + auto r = parser.parse("DELETE FROM users WHERE id=1", 28); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT); + 
EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifySet) { + auto r = parser.parse("SET autocommit=0", 16); + EXPECT_EQ(r.stmt_type, StmtType::SET); +} + +TEST_F(MySQLClassifierTest, ClassifyUse) { + auto r = parser.parse("USE mydb", 8); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::USE); + EXPECT_EQ(std::string(r.database_name.ptr, r.database_name.len), "mydb"); +} + +TEST_F(MySQLClassifierTest, ClassifyBegin) { + auto r = parser.parse("BEGIN", 5); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::BEGIN); +} + +TEST_F(MySQLClassifierTest, ClassifyStartTransaction) { + auto r = parser.parse("START TRANSACTION", 17); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::START_TRANSACTION); +} + +TEST_F(MySQLClassifierTest, ClassifyCommit) { + auto r = parser.parse("COMMIT", 6); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::COMMIT); +} + +TEST_F(MySQLClassifierTest, ClassifyRollback) { + auto r = parser.parse("ROLLBACK", 8); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::ROLLBACK); +} + +TEST_F(MySQLClassifierTest, ClassifyCreateTable) { + auto r = parser.parse("CREATE TABLE users (id INT)", 27); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::CREATE); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyDropTable) { + auto r = parser.parse("DROP TABLE users", 16); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DROP); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyShow) { + auto r = parser.parse("SHOW TABLES", 11); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SHOW); +} + +TEST_F(MySQLClassifierTest, ClassifyReplace) { + auto r = parser.parse("REPLACE INTO users VALUES 
(1)", 29); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REPLACE); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyGrant) { + auto r = parser.parse("GRANT SELECT ON users TO 'app'", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::GRANT); +} + +TEST_F(MySQLClassifierTest, ClassifyRevoke) { + auto r = parser.parse("REVOKE ALL ON users FROM 'app'", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REVOKE); +} + +TEST_F(MySQLClassifierTest, ClassifyLock) { + auto r = parser.parse("LOCK TABLES users WRITE", 23); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::LOCK); +} + +TEST_F(MySQLClassifierTest, ClassifyDeallocate) { + auto r = parser.parse("DEALLOCATE PREPARE stmt1", 24); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DEALLOCATE); +} + +TEST_F(MySQLClassifierTest, ClassifyUnknown) { + auto r = parser.parse("EXPLAIN SELECT 1", 16); + EXPECT_EQ(r.stmt_type, StmtType::UNKNOWN); +} + +TEST_F(MySQLClassifierTest, EmptyInput) { + auto r = parser.parse("", 0); + EXPECT_EQ(r.status, ParseResult::ERROR); + EXPECT_EQ(r.stmt_type, StmtType::UNKNOWN); +} + +TEST_F(MySQLClassifierTest, MultiStatement) { + const char* sql = "BEGIN; SELECT 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.stmt_type, StmtType::BEGIN); + EXPECT_TRUE(r.has_remaining()); + // remaining should point to " SELECT 1" + EXPECT_GT(r.remaining.len, 0u); +} + +TEST_F(MySQLClassifierTest, CaseInsensitive) { + auto r = parser.parse("insert into USERS values (1)", 28); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "USERS"); +} + +// ========== PostgreSQL Classifier Tests ========== + +class PgSQLClassifierTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(PgSQLClassifierTest, ClassifySelect) { + auto r = 
parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); +} + +TEST_F(PgSQLClassifierTest, ClassifyInsert) { + auto r = parser.parse("INSERT INTO users VALUES (1)", 28); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(PgSQLClassifierTest, ClassifyReset) { + auto r = parser.parse("RESET ALL", 9); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::RESET); +} +``` + +- [ ] **Step 2: Write parser.h** + +Create `include/sql_parser/parser.h`: +```cpp +#ifndef SQL_PARSER_PARSER_H +#define SQL_PARSER_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/parse_result.h" + +namespace sql_parser { + +struct ParserConfig { + size_t arena_block_size = 65536; // 64KB + size_t arena_max_size = 1048576; // 1MB +}; + +template +class Parser { +public: + explicit Parser(const ParserConfig& config = {}); + ~Parser() = default; + + // Non-copyable, non-movable + Parser(const Parser&) = delete; + Parser& operator=(const Parser&) = delete; + + // Parse a SQL string. Returns ParseResult with classification + metadata. + // For Tier 1 statements (SELECT, SET), returns PARTIAL until deep parsers + // are implemented (future plan). + ParseResult parse(const char* sql, size_t len); + + // Reset the arena. Call after each query is fully processed. 
+ void reset(); + +private: + Arena arena_; + Tokenizer tokenizer_; + + // Classifier: dispatches to the right extractor/parser + ParseResult classify_and_dispatch(); + + // Tier 1 stubs (return PARTIAL with stmt_type set) + ParseResult parse_select(); + ParseResult parse_set(); + + // Tier 2 extractors + ParseResult extract_insert(const Token& first); + ParseResult extract_update(const Token& first); + ParseResult extract_delete(const Token& first); + ParseResult extract_replace(const Token& first); + ParseResult extract_transaction(const Token& first); + ParseResult extract_use(const Token& first); + ParseResult extract_show(const Token& first); + ParseResult extract_prepare(const Token& first); + ParseResult extract_execute(const Token& first); + ParseResult extract_deallocate(const Token& first); + ParseResult extract_ddl(const Token& first); + ParseResult extract_acl(const Token& first); + ParseResult extract_lock(const Token& first); + ParseResult extract_load(const Token& first); + ParseResult extract_reset(const Token& first); + ParseResult extract_unknown(const Token& first); + + // Helpers + // Read optional schema.table or just table. Returns table token. + // If qualified (schema.table), sets schema_out. 
+ Token read_table_name(StringRef& schema_out); + + // Scan forward to semicolon or EOF, set result.remaining + void scan_to_end(ParseResult& result); +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_PARSER_H +``` + +- [ ] **Step 3: Write parser.cpp** + +Create `src/sql_parser/parser.cpp`: +```cpp +#include "sql_parser/parser.h" + +namespace sql_parser { + +template +Parser::Parser(const ParserConfig& config) + : arena_(config.arena_block_size, config.arena_max_size) {} + +template +void Parser::reset() { + arena_.reset(); +} + +template +ParseResult Parser::parse(const char* sql, size_t len) { + arena_.reset(); + tokenizer_.reset(sql, len); + return classify_and_dispatch(); +} + +template +ParseResult Parser::classify_and_dispatch() { + Token first = tokenizer_.next_token(); + + if (first.type == TokenType::TK_EOF) { + ParseResult r; + r.status = ParseResult::ERROR; + r.stmt_type = StmtType::UNKNOWN; + return r; + } + + switch (first.type) { + case TokenType::TK_SELECT: return parse_select(); + case TokenType::TK_SET: return parse_set(); + case TokenType::TK_INSERT: return extract_insert(first); + case TokenType::TK_UPDATE: return extract_update(first); + case TokenType::TK_DELETE: return extract_delete(first); + case TokenType::TK_REPLACE: return extract_replace(first); + case TokenType::TK_BEGIN: + case TokenType::TK_START: + case TokenType::TK_COMMIT: + case TokenType::TK_ROLLBACK: + case TokenType::TK_SAVEPOINT:return extract_transaction(first); + case TokenType::TK_USE: return extract_use(first); + case TokenType::TK_SHOW: return extract_show(first); + case TokenType::TK_PREPARE: return extract_prepare(first); + case TokenType::TK_EXECUTE: return extract_execute(first); + case TokenType::TK_DEALLOCATE: return extract_deallocate(first); + case TokenType::TK_CREATE: + case TokenType::TK_ALTER: + case TokenType::TK_DROP: + case TokenType::TK_TRUNCATE: return extract_ddl(first); + case TokenType::TK_GRANT: + case TokenType::TK_REVOKE: return 
extract_acl(first); + case TokenType::TK_LOCK: + case TokenType::TK_UNLOCK: return extract_lock(first); + case TokenType::TK_LOAD: return extract_load(first); + case TokenType::TK_RESET: return extract_reset(first); + default: return extract_unknown(first); + } +} + +// ---- Tier 1 stubs ---- + +template +ParseResult Parser::parse_select() { + ParseResult r; + r.status = ParseResult::PARTIAL; + r.stmt_type = StmtType::SELECT; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::parse_set() { + ParseResult r; + r.status = ParseResult::PARTIAL; + r.stmt_type = StmtType::SET; + scan_to_end(r); + return r; +} + +// ---- Helpers ---- + +template +Token Parser::read_table_name(StringRef& schema_out) { + Token name = tokenizer_.next_token(); + if (name.type != TokenType::TK_IDENTIFIER && + name.type != TokenType::TK_EOF) { + // Keywords used as table names (e.g., CREATE TABLE `user`) + // The tokenizer returns keyword tokens for reserved words. + // Accept any non-punctuation token as a potential name. 
+ } + + // Check for qualified name: schema.table + if (tokenizer_.peek().type == TokenType::TK_DOT) { + schema_out = name.text; + tokenizer_.skip(); // consume dot + Token table = tokenizer_.next_token(); + return table; + } + + schema_out = StringRef{}; + return name; +} + +template +void Parser::scan_to_end(ParseResult& result) { + while (true) { + Token t = tokenizer_.next_token(); + if (t.type == TokenType::TK_EOF) break; + if (t.type == TokenType::TK_SEMICOLON) { + Token next = tokenizer_.peek(); + if (next.type != TokenType::TK_EOF) { + const char* remaining_start = next.text.ptr; + const char* input_end = tokenizer_.input_end(); + result.remaining = StringRef{ + remaining_start, + static_cast(input_end - remaining_start) + }; + } + break; + } + } +} + +// ---- Tier 2 Extractors ---- + +template +ParseResult Parser::extract_insert(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::INSERT; + + // Expect optional INTO + Token t = tokenizer_.peek(); + if (t.type == TokenType::TK_INTO) { + tokenizer_.skip(); + } + + // Read table name + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_update(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::UPDATE; + + // Optional LOW_PRIORITY / IGNORE + Token t = tokenizer_.peek(); + while (t.type == TokenType::TK_LOW_PRIORITY || t.type == TokenType::TK_IGNORE) { + tokenizer_.skip(); + t = tokenizer_.peek(); + } + + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_delete(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::DELETE_STMT; + + // Optional LOW_PRIORITY / QUICK / IGNORE + Token t = tokenizer_.peek(); + while (t.type == TokenType::TK_LOW_PRIORITY || + t.type == TokenType::TK_QUICK || + t.type 
== TokenType::TK_IGNORE) { + tokenizer_.skip(); + t = tokenizer_.peek(); + } + + // Expect FROM + if (tokenizer_.peek().type == TokenType::TK_FROM) { + tokenizer_.skip(); + } + + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_replace(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::REPLACE; + + if (tokenizer_.peek().type == TokenType::TK_INTO) { + tokenizer_.skip(); + } + + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_transaction(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + + switch (first.type) { + case TokenType::TK_BEGIN: + r.stmt_type = StmtType::BEGIN; + break; + case TokenType::TK_START: + r.stmt_type = StmtType::START_TRANSACTION; + // consume TRANSACTION if present + if (tokenizer_.peek().type == TokenType::TK_TRANSACTION) + tokenizer_.skip(); + break; + case TokenType::TK_COMMIT: + r.stmt_type = StmtType::COMMIT; + break; + case TokenType::TK_ROLLBACK: + r.stmt_type = StmtType::ROLLBACK; + break; + case TokenType::TK_SAVEPOINT: + r.stmt_type = StmtType::SAVEPOINT; + break; + default: + r.stmt_type = StmtType::UNKNOWN; + break; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_use(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::USE; + + Token db = tokenizer_.next_token(); + r.database_name = db.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_show(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::SHOW; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_prepare(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::PREPARE; + scan_to_end(r); + return r; +} + +template 
+ParseResult Parser::extract_execute(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::EXECUTE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_deallocate(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::DEALLOCATE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_ddl(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + + switch (first.type) { + case TokenType::TK_CREATE: r.stmt_type = StmtType::CREATE; break; + case TokenType::TK_ALTER: r.stmt_type = StmtType::ALTER; break; + case TokenType::TK_DROP: r.stmt_type = StmtType::DROP; break; + case TokenType::TK_TRUNCATE: r.stmt_type = StmtType::TRUNCATE; break; + default: r.stmt_type = StmtType::UNKNOWN; break; + } + + // Try to extract object name: CREATE/ALTER/DROP [IF EXISTS/NOT EXISTS] TABLE/INDEX/VIEW name + Token t = tokenizer_.next_token(); + + // Skip optional IF [NOT] EXISTS + if (t.type == TokenType::TK_IF) { + t = tokenizer_.next_token(); // NOT or EXISTS + if (t.type == TokenType::TK_NOT) { + t = tokenizer_.next_token(); // EXISTS + } + // Skip EXISTS + t = tokenizer_.next_token(); // should be TABLE/INDEX/etc. 
+ } + + // Now t should be TABLE, INDEX, VIEW, DATABASE, SCHEMA, or a name + if (t.type == TokenType::TK_TABLE || t.type == TokenType::TK_INDEX || + t.type == TokenType::TK_VIEW || t.type == TokenType::TK_DATABASE || + t.type == TokenType::TK_SCHEMA) { + // Optional IF [NOT] EXISTS after object type for CREATE/DROP + Token maybe_if = tokenizer_.peek(); + if (maybe_if.type == TokenType::TK_IF) { + tokenizer_.skip(); // IF + Token next = tokenizer_.next_token(); + if (next.type == TokenType::TK_NOT) { + tokenizer_.skip(); // EXISTS + } + } + Token name = read_table_name(r.schema_name); + r.table_name = name.text; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_acl(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = (first.type == TokenType::TK_GRANT) ? StmtType::GRANT : StmtType::REVOKE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_lock(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = (first.type == TokenType::TK_LOCK) ? StmtType::LOCK : StmtType::UNLOCK; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_load(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::LOAD_DATA; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_reset(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::RESET; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_unknown(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::UNKNOWN; + scan_to_end(r); + return r; +} + +// ---- Explicit template instantiations ---- + +template class Parser; +template class Parser; + +} // namespace sql_parser +``` + +- [ ] **Step 4: Build and run tests** + +Run: +```bash +make -f Makefile.new test +``` +Expected: All classifier and extractor tests PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add include/sql_parser/parser.h src/sql_parser/parser.cpp tests/test_classifier.cpp +git commit -m "feat: add classifier and Tier 2 extractors for all statement types" +``` + +--- + +### Task 8: Integration Smoke Test and .gitignore Update + +**Files:** +- Modify: `.gitignore` + +- [ ] **Step 1: Update .gitignore** + +Append to `.gitignore`: +``` +# New parser build artifacts +libsqlparser.a +run_tests +run_bench +``` + +- [ ] **Step 2: Run full build from clean** + +Run: +```bash +make -f Makefile.new clean && make -f Makefile.new all +``` +Expected: Builds `libsqlparser.a`, runs all tests, all PASS. + +- [ ] **Step 3: Commit** + +```bash +git add .gitignore +git commit -m "chore: update .gitignore for new parser build artifacts" +``` + +--- + +## What's Next + +After this plan is complete, the following plans should be created (in order): + +1. **Plan 2: Expression Parser + SET Deep Parser** — Pratt expression parser, full SET statement parser with AST construction, round-trip tests. +2. **Plan 3: SELECT Deep Parser** — Full SELECT statement parsing with all clauses, using the expression parser from Plan 2. +3. **Plan 4: Query Emitter** — AST → SQL reconstruction, round-trip testing. +4. **Plan 5: Prepared Statement Cache** — Statement cache, binary protocol support, `parse_and_cache` / `execute` API. +5. **Plan 6: Performance Benchmarks + Optimization** — Google Benchmark integration, performance validation against targets, optimization passes (perfect hash for keywords, etc.). diff --git a/docs/superpowers/specs/2026-03-24-sql-parser-design.md b/docs/superpowers/specs/2026-03-24-sql-parser-design.md new file mode 100644 index 0000000..b60289a --- /dev/null +++ b/docs/superpowers/specs/2026-03-24-sql-parser-design.md @@ -0,0 +1,490 @@ +# SQL Parser for ProxySQL — Design Specification + +## Overview + +A high-performance, hand-written recursive descent SQL parser for ProxySQL. 
Supports both MySQL and PostgreSQL dialects. Designed for sub-microsecond latency on the proxy hot path. + +### Goals + +- **Two-tier parsing:** Deep parse (full AST + reconstruction) for SELECT and SET. Classify-and-extract for all other statement types. +- **Both dialects from the start:** MySQL and PostgreSQL via compile-time dialect templating. No runtime dispatch overhead. +- **Both protocols:** Text protocol (COM_QUERY) and binary protocol (COM_STMT_PREPARE / COM_STMT_EXECUTE). +- **Query reconstruction:** Parse → modify AST → emit valid SQL. With an option to use read-only inspection mode when reconstruction isn't needed. +- **Sub-microsecond latency:** Arena allocation, zero-copy string references, no exceptions, no virtual dispatch. +- **Thread-safe by isolation:** One parser instance per thread. No shared mutable state, no locks. + +### Constraints + +- C++17 floor. Optional C++20 features behind `#if __cplusplus >= 202002L` where they provide measurable performance gains. +- Must compile on AlmaLinux 8 (GCC 8.5) through Fedora 43 and macOS (Apple Silicon). +- Static library or header-only (whichever yields better performance after benchmarking). Linked into ProxySQL at build time. + +### Migration from Existing POC + +This project replaces the existing Flex/Bison-based POC parser wholesale. The old parser (Flex lexer, Bison LALR grammar, `std::string`-based AST with `std::vector` children, per-node heap allocation) is not carried forward. The existing `src/mysql_parser/`, `src/pgsql_parser/`, and `include/` directories will be removed once the new parser is functional. Existing examples in `examples/` serve as a test corpus for validating that the new parser handles the same queries correctly, then they too will be replaced. + +--- + +## Architecture + +Three-layer architecture: Tokenizer → Classifier → Statement Parsers. 
+ +``` +Input SQL bytes + │ + ▼ +┌──────────────┐ +│ Tokenizer │ Zero-copy, dialect-templated, pull-based +│ │ Produces Token {type, StringRef, offset} +└──────┬───────┘ + │ + ▼ +┌──────────────┐ +│ Classifier │ 1-3 token lookahead, switch on first keyword +└──────┬───────┘ + │ + ├──── Tier 1 ──► Statement-specific deep parser (SELECT, SET) + │ │ + │ ▼ + │ Full AST in arena → ParseResult + │ + └──── Tier 2 ──► Lightweight extractor + │ + ▼ + StmtType + metadata → ParseResult +``` + +--- + +## Layer 1: Memory Model & Core Data Structures + +### Arena Allocator + +Each parser instance owns a thread-local arena — a pre-allocated memory block (64KB default). All AST nodes, materialized strings, and temporary data are allocated from the arena. After a query is fully processed, the arena resets (pointer rewind, O(1)). No per-node new/delete. + +**Growth strategy:** The arena uses block chaining — never `realloc` (which would invalidate all pointers). When the current block is exhausted, a new block is allocated and linked. `reset()` retains the first (primary) block and frees any overflow blocks. This means `reset()` is O(1) in the common case (single block) and O(n_overflow_blocks) in the rare case. A configurable maximum arena size (default: 1MB) prevents unbounded growth; exceeding it returns `ParseResult::ERROR`. + +``` +┌──────────────────────┐ ┌──────────────────────┐ +│ Block 1 (64KB) │───►│ Block 2 (overflow) │───► ... +│ [AstNode][string].. │ │ [AstNode][string].. 
│ +│ ^ │ │ ^ │ +│ cursor │ │ cursor │ +│ │ │ │ +│ reset() → cursor=0, │ │ (freed on reset) │ +│ free overflow blocks│ │ │ +└──────────────────────┘ └──────────────────────┘ +``` + +### StringRef (Zero-Copy) + +```cpp +struct StringRef { + const char* ptr; + uint32_t len; + + // Comparison, hashing helpers +}; +static_assert(std::is_trivially_copyable_v); +``` + +`StringRef` must remain a trivial type (no constructors, destructors, or virtual functions) to be safely used in unions and to enable memcpy-based operations. + +Points into the original input buffer. No copies or allocations for identifiers, keywords, or literals. The input SQL string must outlive the parse result (natural in ProxySQL — the query buffer is session-owned). + +When a string must be materialized (e.g., unescaping a quoted literal), it is allocated from the arena. + +### AstNode + +Flat, compact struct. No virtual functions, no `std::vector`. Children use an intrusive linked list (first_child + next_sibling): + +```cpp +struct AstNode { + AstNode* first_child; // 8 bytes — first child (intrusive list) + AstNode* next_sibling; // 8 bytes — next sibling + const char* value_ptr; // 8 bytes — pointer into input (inlined StringRef) + uint32_t value_len; // 4 bytes — length + NodeType type; // 2 bytes — enum + uint16_t flags; // 2 bytes — dialect bits, tier, modifiers +}; +// 32 bytes per node on 64-bit — exactly half a cache line. +// Fields ordered to avoid padding: pointers first, then 4-byte, then 2-byte. +static_assert(sizeof(AstNode) == 32); +``` + +### ParseResult + +```cpp +struct ParseResult { + enum Status { OK, PARTIAL, ERROR }; + Status status; + StmtType stmt_type; // always set, even on error (best-effort) + AstNode* ast; // non-null for Tier 1 OK + ErrorInfo error; // populated on ERROR/PARTIAL + StringRef remaining; // unparsed input after semicolon (for multi-statement) + + // Tier 2 extracted metadata + StringRef table_name; + StringRef schema_name; + // etc. 
+}; + +struct ErrorInfo { + uint32_t offset; // byte position in input + StringRef message; // arena-allocated +}; +``` + +`PARTIAL` semantics by tier: +- **Tier 1:** classifier succeeded (statement type known) but the deep parser hit a syntax error. The AST may be partially populated. ProxySQL can still route on statement type — let the backend report the error. +- **Tier 2:** classifier succeeded but the extractor could not find expected metadata (e.g., `INSERT INTO` with no table name following). `stmt_type` is set but metadata fields may be empty. +- `ERROR` means the first token could not be classified at all (e.g., binary garbage or empty input). + +**Lifetime note:** `ErrorInfo::message` points to arena-allocated memory. It becomes invalid after `parser.reset()`. Consumers must copy the message if they need it beyond the parse lifecycle. + +--- + +## Layer 2: Tokenizer + +Pull-based iterator. The parser calls `next_token()` which advances the cursor and returns a `Token`. No allocation, no token stream buffering. + +### Token + +```cpp +struct Token { + TokenType type; // enum: keyword, identifier, number, string, operator, etc. + StringRef text; // points into original input buffer + uint32_t offset; // byte position (for error reporting) +}; +``` + +### Dialect Templating + +```cpp +template +class Tokenizer { + const char* cursor_; + const char* end_; + Token peeked_; // one-token lookahead cache + bool has_peeked_; +public: + void reset(const char* input, size_t len); + Token next_token(); + Token peek(); + void skip(); +}; +``` + +Dialect-specific behavior resolved at compile time via `if constexpr`: + +| Feature | MySQL | PostgreSQL | +|---|---|---| +| Quoted identifiers | `` `backtick` `` | `"double quote"` | +| String literals | `'single'` or `"double"` | `'single'`, `$$dollar$$` | +| Comments | `-- `, `#`, `/* */` | `-- `, `/* */` (nested) | +| Operators | `:=` assignment | `::` cast, `~` regex | +| Placeholders | `?` | `$1`, `$2`, ... 
| + +### Keyword Recognition + +Perfect hash function (generated at build time or constexpr-computed). SQL has a bounded keyword set per dialect (~200-300 keywords). Single array lookup, O(1). + +```cpp +struct KeywordEntry { + StringRef text; + TokenType token; + uint16_t flags; // reserved vs non-reserved, dialect mask +}; +``` + +### Lookahead + +The tokenizer provides single-token lookahead via `peek()`, cached internally. Statement parsers that need multi-token disambiguation (e.g., `SET TRANSACTION` vs `SET var = ...`, or `INSERT INTO ... SELECT` vs `INSERT INTO ... VALUES`) handle this by consuming tokens and using the parser's own state to disambiguate — no backtracking needed. For example, the SET parser consumes the second token; if it's `TRANSACTION`, it enters `parse_set_transaction()`, otherwise it treats the consumed token as the start of a variable target. This is standard recursive descent practice and does not require a multi-token lookahead buffer in the tokenizer itself. + +--- + +## Layer 3: Statement Classifier & Router + +Consumes 1-3 tokens to identify statement type. Switch on first token's enum value (jump table). 
+ +```cpp +template +ParseResult Parser::parse(const char* sql, size_t len) { + tokenizer_.reset(sql, len); + Token first = tokenizer_.next_token(); + + switch (first.type) { + case TK_SELECT: return parse_select(); + case TK_SET: return parse_set(); + case TK_INSERT: return extract_insert(first); + case TK_UPDATE: return extract_update(first); + case TK_DELETE: return extract_delete(first); + case TK_BEGIN: + case TK_START: return extract_transaction(first); + case TK_COMMIT: + case TK_ROLLBACK:return extract_transaction(first); + case TK_SHOW: return extract_show(first); + case TK_USE: return extract_use(first); + case TK_PREPARE: return extract_prepare(first); + case TK_EXECUTE: return extract_execute(first); + case TK_CREATE: + case TK_ALTER: + case TK_DROP: + case TK_TRUNCATE:return extract_ddl(first); + case TK_GRANT: + case TK_REVOKE: return extract_acl(first); + default: return extract_unknown(first); + } +} +``` + +### Tier 2 Extractors + +Lightweight — scan forward to extract key pieces without building a full AST. Example: INSERT extractor reads `INTO`, then table name (1-2 tokens for qualified name), returns `ParseResult{stmt_type=INSERT, table=StringRef}`. + +### Promotion Path + +Promoting a Tier 2 statement to Tier 1: replace the extractor function with a recursive descent parser module. Classifier and everything else unchanged. 
+ +--- + +## Tier 1 Deep Parsers + +### SELECT Parser + +``` +parse_select() + ├── parse_select_options() // DISTINCT, ALL, SQL_CALC_FOUND_ROWS (MySQL) + ├── parse_select_item_list() // expressions, aliases, * + │ └── parse_expression() // Pratt parser: operators, subqueries, functions + ├── parse_from_clause() + │ ├── parse_table_ref() + │ └── parse_join() // JOIN type, ON/USING + ├── parse_where_clause() + │ └── parse_expression() + ├── parse_group_by() + ├── parse_having() + ├── parse_order_by() + ├── parse_limit() // LIMIT/OFFSET (MySQL) vs LIMIT/FETCH (PgSQL) + ├── parse_locking() // FOR UPDATE/SHARE (dialect-specific) + └── parse_into() // INTO OUTFILE/DUMPFILE (MySQL only) +``` + +**Expression parsing** uses Pratt parsing (precedence climbing). A single `parse_expression(min_precedence)` handles unary, binary, ternary (BETWEEN), IS [NOT] NULL, IN (...), CASE/WHEN, subqueries, and function calls. + +**Dialect branching** via `if constexpr (D == Dialect::PostgreSQL)` for: +- `::` type cast +- `LIMIT ALL` vs `LIMIT` with expression +- Dollar-quoted strings + +Note: PostgreSQL's `RETURNING` clause applies to INSERT/UPDATE/DELETE, not SELECT. It will be handled when those statements are promoted to Tier 1. Until then, Tier 2 extractors for those statements will detect `RETURNING` and include it in metadata but not build an AST for the returned expressions. + +### SET Parser + +``` +parse_set() + ├── SET NAMES 'charset' [COLLATE 'collation'] + ├── SET CHARACTER SET 'charset' + ├── SET TRANSACTION [READ ONLY | READ WRITE | ISOLATION LEVEL ...] + └── SET [GLOBAL|SESSION|@@...] var = expr [, var = expr, ...] + ├── parse_variable_target() // @user_var, @@global.sys_var, plain name + └── parse_expression() + +PostgreSQL variants: + ├── SET name TO value + ├── SET name = value + ├── SET LOCAL name = value + └── RESET name / RESET ALL +``` + +The SET AST preserves full detail for query reconstruction — ProxySQL actively rewrites SET statements. 
+ +--- + +## Query Reconstruction (Emitter) + +Each `NodeType` has a corresponding emit function. The emitter is dialect-templated, walking the AST and writing into an arena-backed output buffer. + +- `StringRef` values are emitted directly from the original input (no copy unless the node was modified). +- Modified nodes emit their new values. + +Cross-dialect emission (parse MySQL → emit PostgreSQL) is **out of scope** for the initial design. Many constructs have no direct equivalent across dialects (`SQL_CALC_FOUND_ROWS`, backtick quoting, `LIMIT` syntax differences). The emitter always emits in the same dialect it parsed. + +--- + +## Binary Protocol & Prepared Statements + +### Lifecycle + +``` +COM_STMT_PREPARE COM_STMT_EXECUTE (repeated) COM_STMT_CLOSE + │ │ │ + ▼ ▼ ▼ + Parse SQL template Bind params to cached AST Evict from cache + Cache AST + metadata Return enriched ParseResult +``` + +### Prepare Phase + +SQL template parsed normally. Placeholder tokens (`?` in MySQL, `$1`/`$2` in PostgreSQL) become `NODE_PLACEHOLDER` AST nodes with a parameter index in `flags`. + +The AST is copied from the arena to a longer-lived **statement cache** (per-parser-instance, keyed by statement ID) via `parse_and_cache()`, which atomically parses and stores the result before the arena can be reset. This is the one place where memory leaves the arena. + +**Threading note:** The statement cache is per-parser-instance (i.e., per-thread). In ProxySQL, prepared statement state is per-session. If sessions can migrate between threads, the session must carry its own prepared statement metadata (statement IDs, SQL templates). The parser on the destination thread can re-parse and cache the template on first execute if the cached AST is not found. This avoids any cross-thread sharing of parser state. 
+ +### Execute Phase + +```cpp +struct BoundValue { + enum Type { INT, FLOAT, DOUBLE, STRING, BLOB, NULL_VAL, DATETIME, DECIMAL }; + Type type; + union { + int64_t int_val; + float float32_val; // MySQL FLOAT (4 bytes) — distinct from DOUBLE + double float64_val; // MySQL DOUBLE (8 bytes) + StringRef str_val; // points into COM_STMT_EXECUTE packet buffer + // also used for DATETIME/DECIMAL (wire-format string) + }; +}; +static_assert(std::is_trivially_copyable_v); + +struct ParamBindings { + BoundValue* values; + uint16_t count; +}; +``` + +Parser looks up cached AST by statement ID, returns `ParseResult` carrying both the AST pointer and `ParamBindings`. Consumers walk the AST and resolve placeholders through the bindings. + +### Reconstruction with Parameters + +The emitter checks placeholder nodes against `ParamBindings` and writes materialized values, producing valid text-protocol SQL from a binary-protocol execution. + +### Statement Cache + +Per-thread, fixed-capacity LRU. Size maps to ProxySQL's `max_prepared_statements` config. Eviction on COM_STMT_CLOSE or LRU overflow. + +--- + +## Error Handling + +The parser never throws exceptions. Errors are reported through `ParseResult::status` and `ParseResult::error`. + +- `OK` — parse succeeded fully. +- `PARTIAL` — classifier succeeded (statement type known), deep parse failed. ProxySQL can route on type; backend reports the syntax error. +- `ERROR` — could not classify at all. + +**Lenient by design:** ProxySQL doesn't need to reject queries — the backend does. The parser extracts as much useful information as possible and degrades gracefully. + +### Multi-Statement Queries + +ProxySQL regularly receives semicolon-separated multi-statement queries (e.g., `SET autocommit=0; BEGIN`). The parser handles this by parsing the **first statement** and returning its `ParseResult` along with a `remaining` field (`StringRef` pointing to the unparsed tail after the semicolon). 
The caller is responsible for calling `parse()` again on the remainder if needed. This avoids allocating a list of results and lets the caller decide whether to parse subsequent statements. + +### Maximum Query Length + +The parser respects the caller's buffer size (the `len` parameter). It does not impose its own maximum query length — that is ProxySQL's responsibility (via `mysql-max_allowed_packet` or equivalent). The arena's maximum size (default 1MB) provides an implicit bound on the complexity of parseable queries; exceeding it returns `ERROR`. + +--- + +## Public API + +```cpp +template +class Parser { +public: + Parser(const ParserConfig& config = {}); // arena size, cache capacity + + ParseResult parse(const char* sql, size_t len); + ParseResult parse_and_cache(const char* sql, size_t len, uint32_t stmt_id); + ParseResult execute(uint32_t stmt_id, const ParamBindings& params); + + void prepare_cache_evict(uint32_t stmt_id); + + void reset(); // resets arena; call after each query is fully processed +}; +``` + +One `Parser` or `Parser` per thread, created at thread startup. + +--- + +## Testing Strategy + +1. **Unit tests per module** — each statement parser and extractor gets its own test file. Feed SQL strings, assert on AST structure or extracted metadata. +2. **Round-trip tests** — parse → reconstruct → parse → compare ASTs. Run across a corpus of real-world queries. +3. 
**Performance benchmarks** (Google Benchmark or similar): + - Tier 2 classification latency (target: <100ns) + - Tier 1 full parse latency (target: <1us for typical SELECT/SET) + - Arena memory high-water mark per query + +--- + +## File Organization + +``` +include/sql_parser/ + common.h // StringRef, Arena, Token, enums + tokenizer.h // Tokenizer template + ast.h // AstNode, NodeType + parser.h // Parser public API + parse_result.h // ParseResult, ErrorInfo, BoundValue + keywords_mysql.h // MySQL keyword table + keywords_pgsql.h // PostgreSQL keyword table + +src/sql_parser/ + tokenizer.cpp // explicit template instantiations (or header-only with LTO for max inlining) + classifier.cpp // switch dispatch + select_parser.cpp // Tier 1: SELECT + set_parser.cpp // Tier 1: SET + expression_parser.cpp // shared Pratt expression parser + extractors.cpp // Tier 2: all lightweight extractors + emitter.cpp // query reconstruction + stmt_cache.cpp // prepared statement cache + arena.cpp // arena allocator + +tests/ + test_tokenizer.cpp + test_select.cpp + test_set.cpp + test_extractors.cpp + test_roundtrip.cpp + +bench/ + bench_classify.cpp + bench_select.cpp + bench_set.cpp +``` + +--- + +## Performance Targets + +| Operation | Target Latency | Notes | +|---|---|---| +| Tier 2 classification | <100ns | 1-3 token read + switch | +| Tier 1 SELECT parse (simple) | <500ns | `SELECT col FROM t WHERE id = 1` | +| Tier 1 SELECT parse (complex) | <2us | Multi-join, subqueries, GROUP BY, ORDER BY | +| Tier 1 SET parse | <300ns | `SET @@session.var = value` | +| Query reconstruction | <500ns | Simple SELECT round-trip | +| Arena reset | <10ns | Pointer rewind (single-block case; overflow blocks add O(n) free calls) | + +--- + +## Statement Tier Classification + +### Tier 1 — Deep Parse (full AST, reconstruction) + +- SELECT +- SET + +### Tier 2 — Classify + Extract Key Metadata + +- INSERT, UPDATE, DELETE, REPLACE +- USE, SHOW +- BEGIN, START TRANSACTION, COMMIT, ROLLBACK, SAVEPOINT 
+- PREPARE, EXECUTE, DEALLOCATE +- CREATE, ALTER, DROP, TRUNCATE +- GRANT, REVOKE +- LOCK, UNLOCK +- LOAD DATA +- All other statements (classified as UNKNOWN with raw text) diff --git a/docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md b/docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md new file mode 100644 index 0000000..e2d3273 --- /dev/null +++ b/docs/superpowers/specs/2026-03-24-tier1-promotions-and-digest-design.md @@ -0,0 +1,433 @@ +# Tier 1 Promotions, UNION Support & Query Digest — Design Specification + +## Overview + +Extends the SQL parser with full Tier 1 deep parsing for INSERT, UPDATE, and DELETE (both MySQL and PostgreSQL dialects), adds UNION/INTERSECT/EXCEPT compound query support with recursive nesting, and introduces AST-based query digest/normalization for query rules matching. + +### Goals + +- **INSERT Tier 1:** Full AST for INSERT/REPLACE with VALUES, SELECT, SET, ON DUPLICATE KEY UPDATE (MySQL), ON CONFLICT (PostgreSQL), RETURNING (PostgreSQL). +- **UPDATE Tier 1:** Full AST with multi-table JOIN (MySQL), FROM (PostgreSQL), ORDER BY/LIMIT (MySQL), RETURNING (PostgreSQL). +- **DELETE Tier 1:** Full AST with multi-table (MySQL both forms), USING (PostgreSQL), ORDER BY/LIMIT (MySQL), RETURNING (PostgreSQL). +- **Compound queries:** UNION [ALL], INTERSECT [ALL], EXCEPT [ALL] with parenthesized nesting and precedence (INTERSECT binds tighter). +- **Query digest:** AST-based normalization (literals → `?`, IN list collapsing, keyword uppercasing) + 64-bit hash. Works for all statement types including Tier 2 (token-level fallback). + +### Constraints + +- Same as original spec: C++17 floor, both dialects, sub-microsecond targets, arena allocation, header-only parsers. +- All new parsers follow the established pattern: `XxxParser` header-only template, uses `ExpressionParser`, integrated via `parser.cpp`. +- Emitter extended for all new node types + digest mode. 
+ +### Prerequisite Refactoring: Shared Table Reference Parsing + +The existing `SelectParser` has `parse_from_clause()`, `parse_table_reference()`, `parse_join()`, and `parse_optional_alias()` as private methods. These are needed by `InsertParser` (for INSERT ... SELECT), `UpdateParser` (for MySQL multi-table UPDATE), and `DeleteParser` (for MySQL multi-table DELETE and PostgreSQL USING). + +**Solution:** Extract these methods into a shared `TableRefParser` utility class that takes a `Tokenizer&` and `Arena&`. All parsers (SelectParser, InsertParser, UpdateParser, DeleteParser) instantiate it internally. SelectParser's private methods are replaced with calls to TableRefParser. + +``` +include/sql_parser/ + table_ref_parser.h — shared FROM/JOIN/table reference parsing +``` + +This refactoring is a prerequisite for Plans 7-9 and should be done as the first task of Plan 7. + +### Classifier Updates + +The classifier switch in `Parser::classify_and_dispatch()` must be updated: + +- `TK_INSERT` → `parse_insert()` (was `extract_insert()`) +- `TK_UPDATE` → `parse_update()` (was `extract_update()`) +- `TK_DELETE` → `parse_delete()` (was `extract_delete()`) +- `TK_REPLACE` → `parse_insert()` with a REPLACE flag (was `extract_replace()`) +- `TK_SELECT` → compound query aware `parse_select()` (existing, updated) + +### is_alias_start() Update + +The `is_alias_start()` blocklist in SelectParser must be updated to include new clause-starting keywords: `TK_RETURNING`, `TK_INTERSECT`, `TK_EXCEPT`, `TK_CONFLICT`, `TK_DO`, `TK_NOTHING`, `TK_DUPLICATE`. + +--- + +## New NodeType Additions + +```cpp +// INSERT nodes +NODE_INSERT_STMT, +NODE_INSERT_COLUMNS, // (col1, col2, ...) +NODE_VALUES_CLAUSE, // VALUES keyword wrapper +NODE_VALUES_ROW, // single (val1, val2, ...) row +NODE_INSERT_SET_CLAUSE, // MySQL INSERT ... 
SET col=val form +NODE_ON_DUPLICATE_KEY, // MySQL ON DUPLICATE KEY UPDATE +NODE_ON_CONFLICT, // PostgreSQL ON CONFLICT +NODE_CONFLICT_TARGET, // PostgreSQL conflict target (cols or ON CONSTRAINT) +NODE_CONFLICT_ACTION, // DO UPDATE SET ... or DO NOTHING +NODE_RETURNING_CLAUSE, // PostgreSQL RETURNING expr_list + +// UPDATE nodes +NODE_UPDATE_STMT, +NODE_UPDATE_SET_CLAUSE, // SET col=expr, col=expr in UPDATE context +NODE_UPDATE_SET_ITEM, // single col=expr pair + +// DELETE nodes +NODE_DELETE_STMT, +NODE_DELETE_USING_CLAUSE, // PostgreSQL USING for join-like deletes + +// Compound query nodes +NODE_COMPOUND_QUERY, // root for UNION/INTERSECT/EXCEPT +NODE_SET_OPERATION, // operator (UNION, INTERSECT, EXCEPT) with ALL flag + +// Statement options (shared) +NODE_STMT_OPTIONS, // LOW_PRIORITY, IGNORE, QUICK, DELAYED, etc. +``` + +--- + +## INSERT Deep Parser + +### MySQL Syntax + +```sql +INSERT [LOW_PRIORITY | DELAYED | HIGH_PRIORITY] [IGNORE] [INTO] table_name + [(col1, col2, ...)] + { VALUES (row1), (row2), ... | SELECT ... | SET col=val, ... } + [ON DUPLICATE KEY UPDATE col=expr, col=expr, ...] + +REPLACE [LOW_PRIORITY | DELAYED] [INTO] table_name + [(col1, col2, ...)] + { VALUES (row1), (row2), ... | SELECT ... | SET col=val, ... } +``` + +### PostgreSQL Syntax + +```sql +INSERT INTO table_name [(col1, col2, ...)] + { VALUES (row1), (row2), ... | SELECT ... | DEFAULT VALUES } + [ON CONFLICT [(col1, col2, ...)] | [ON CONSTRAINT name] + { DO UPDATE SET col=expr [, ...] [WHERE ...] | DO NOTHING }] + [RETURNING expr_list] +``` + +### AST Structure + +``` +NODE_INSERT_STMT + ├── NODE_STMT_OPTIONS (LOW_PRIORITY, IGNORE, etc.) + ├── NODE_TABLE_REF (table name, optional schema) + ├── NODE_INSERT_COLUMNS (col1, col2, ...) + ├── NODE_VALUES_CLAUSE + │ ├── NODE_VALUES_ROW (val1, val2, ...) + │ └── NODE_VALUES_ROW (val1, val2, ...) + │ OR + ├── NODE_SELECT_STMT (INSERT ... 
SELECT) + │ OR + ├── NODE_INSERT_SET_CLAUSE (MySQL SET col=val) + │ ├── NODE_UPDATE_SET_ITEM (col = expr) + │ └── NODE_UPDATE_SET_ITEM (col = expr) + ├── NODE_ON_DUPLICATE_KEY (MySQL) + │ ├── NODE_UPDATE_SET_ITEM (col = expr) + │ └── NODE_UPDATE_SET_ITEM (col = expr) + │ OR + ├── NODE_ON_CONFLICT (PostgreSQL) + │ ├── NODE_CONFLICT_TARGET (cols or ON CONSTRAINT name) + │ └── NODE_CONFLICT_ACTION (DO UPDATE SET ... WHERE ... or DO NOTHING) + └── NODE_RETURNING_CLAUSE (PostgreSQL) + ├── expression + └── expression +``` + +The parser reuses `ExpressionParser` for all value expressions and `SelectParser` for `INSERT ... SELECT`. The `RETURNING` clause uses the same item-list parsing as SELECT's select item list. + +--- + +## UPDATE Deep Parser + +### MySQL Syntax + +```sql +UPDATE [LOW_PRIORITY] [IGNORE] table_references + SET col=expr [, col=expr, ...] + [WHERE condition] + [ORDER BY ...] + [LIMIT count] +``` + +`table_references` can include JOINs — same grammar as SELECT's FROM clause. + +### PostgreSQL Syntax + +```sql +UPDATE [ONLY] table_name [[AS] alias] + SET col=expr [, col=expr, ...] 
+ [FROM from_list] + [WHERE condition] + [RETURNING expr_list] +``` + +### AST Structure + +``` +NODE_UPDATE_STMT + ├── NODE_STMT_OPTIONS (LOW_PRIORITY, IGNORE) + ├── NODE_TABLE_REF (target table — for both dialects, the primary table being updated) + ├── NODE_FROM_CLAUSE (MySQL: additional JOINed table refs; PostgreSQL: FROM join source) + │ (distinguished by position: always comes after SET clause for PostgreSQL, + │ comes as part of the initial table refs for MySQL) + │ (MySQL multi-table: flags field has FLAG_UPDATE_TARGET_TABLES = 0x01) + ├── NODE_UPDATE_SET_CLAUSE + │ ├── NODE_UPDATE_SET_ITEM (col = expr) + │ └── NODE_UPDATE_SET_ITEM (col = expr) + ├── NODE_WHERE_CLAUSE + ├── NODE_ORDER_BY_CLAUSE (MySQL only) + ├── NODE_LIMIT_CLAUSE (MySQL only) + └── NODE_RETURNING_CLAUSE (PostgreSQL) +``` + +For MySQL multi-table UPDATE, the table references (with JOINs) reuse the shared `TableRefParser` methods. For MySQL, the JOINed tables appear as children of the first `NODE_FROM_CLAUSE` (before SET). For PostgreSQL, the single target table is a `NODE_TABLE_REF`, and the optional `FROM` clause (after SET, before WHERE) is a separate `NODE_FROM_CLAUSE` child. The emitter checks the statement type to determine emission order. + +--- + +## DELETE Deep Parser + +### MySQL Syntax + +```sql +-- Single-table: +DELETE [LOW_PRIORITY] [QUICK] [IGNORE] FROM table_name + [WHERE condition] + [ORDER BY ...] 
+ [LIMIT count] + +-- Multi-table form 1: +DELETE [LOW_PRIORITY] [QUICK] [IGNORE] t1, t2 + FROM table_references + [WHERE condition] + +-- Multi-table form 2: +DELETE [LOW_PRIORITY] [QUICK] [IGNORE] FROM t1, t2 + USING table_references + [WHERE condition] +``` + +### PostgreSQL Syntax + +```sql +DELETE FROM [ONLY] table_name [[AS] alias] + [USING using_list] + [WHERE condition] + [RETURNING expr_list] +``` + +### AST Structure + +``` +NODE_DELETE_STMT + ├── NODE_STMT_OPTIONS (LOW_PRIORITY, QUICK, IGNORE) + ├── NODE_TABLE_REF (target table(s)) + ├── NODE_FROM_CLAUSE (multi-table MySQL: source tables with JOINs) + ├── NODE_DELETE_USING_CLAUSE (MySQL USING or PostgreSQL USING) + ├── NODE_WHERE_CLAUSE + ├── NODE_ORDER_BY_CLAUSE (MySQL single-table only) + ├── NODE_LIMIT_CLAUSE (MySQL single-table only) + └── NODE_RETURNING_CLAUSE (PostgreSQL) +``` + +--- + +## Compound Query Parser (UNION/INTERSECT/EXCEPT) + +### Syntax + +```sql +select_stmt { UNION | INTERSECT | EXCEPT } [ALL] select_stmt + [{ UNION | INTERSECT | EXCEPT } [ALL] select_stmt ...] + [ORDER BY ...] [LIMIT ...] + +-- With parenthesized nesting: +(SELECT ...) UNION ALL (SELECT ... INTERSECT SELECT ...) ORDER BY ... LIMIT ... +``` + +### Precedence + +Per SQL standard: INTERSECT binds tighter than UNION and EXCEPT. 
So: + +```sql +SELECT 1 UNION SELECT 2 INTERSECT SELECT 3 +-- Parses as: SELECT 1 UNION (SELECT 2 INTERSECT SELECT 3) +``` + +Implemented via precedence levels: +- INTERSECT: higher precedence +- UNION, EXCEPT: lower precedence (same level, left-associative) + +### AST Structure + +``` +NODE_COMPOUND_QUERY + ├── NODE_SET_OPERATION (value="UNION ALL") + │ ├── NODE_SELECT_STMT (left) + │ └── NODE_SELECT_STMT (right) + ├── NODE_ORDER_BY_CLAUSE (applies to whole compound) + └── NODE_LIMIT_CLAUSE (applies to whole compound) +``` + +For nested compounds: +``` +NODE_COMPOUND_QUERY + └── NODE_SET_OPERATION (value="UNION") + ├── NODE_SELECT_STMT (left) + └── NODE_SET_OPERATION (value="INTERSECT") + ├── NODE_SELECT_STMT + └── NODE_SELECT_STMT +``` + +### Integration + +A new `CompoundQueryParser` class sits above `SelectParser`. The `parse_select()` method in `Parser` is updated to call `CompoundQueryParser` instead of `SelectParser` directly. + +`CompoundQueryParser` works as follows: +1. Parse the first operand: if `(`, consume it, parse inner compound recursively, expect `)`. Otherwise, call `SelectParser::parse()` for a single SELECT. +2. Check for set operator (UNION/INTERSECT/EXCEPT). If none, return the single SELECT as-is. +3. If found, enter a Pratt-like precedence loop: parse the operator, parse the next operand, build `NODE_SET_OPERATION` nodes respecting INTERSECT > UNION/EXCEPT precedence. +4. After the compound, parse optional trailing ORDER BY / LIMIT (applies to whole result). +5. Wrap in `NODE_COMPOUND_QUERY` and return. + +This layering means `SelectParser` is unchanged — it still parses a single SELECT statement. The compound logic is entirely in `CompoundQueryParser`, which is a separate header-only template. 
+
+```
+include/sql_parser/
+  compound_query_parser.h — UNION/INTERSECT/EXCEPT with precedence
+```
+
+---
+
+## Query Digest / Normalization
+
+### API
+
+```cpp
+template <Dialect D>
+class Digest {
+public:
+  Digest(Arena& arena);
+
+  // From a parsed AST (Tier 1)
+  DigestResult compute(const AstNode* ast);
+
+  // From raw SQL (works for any statement, falls back to token-level for Tier 2)
+  DigestResult compute(const char* sql, size_t len);
+};
+
+struct DigestResult {
+  StringRef normalized; // "SELECT * FROM t WHERE id = ?"
+  uint64_t hash;        // 64-bit hash
+};
+```
+
+### Normalization Rules
+
+1. **Literals → `?`:** Replace `NODE_LITERAL_INT`, `NODE_LITERAL_FLOAT`, `NODE_LITERAL_STRING` with `?`
+2. **IN list collapsing:** `IN (?, ?, ?)` → `IN (?)` (ProxySQL convention — multiple values produce the same digest)
+3. **Keyword uppercasing:** All SQL keywords emitted in uppercase canonical form
+4. **Whitespace normalization:** Single space between tokens, no leading/trailing
+5. **Comment stripping:** Comments already stripped by tokenizer, so this is free
+6. **Backtick/quote stripping:** Identifiers emitted without quotes in digest (optional, configurable)
+
+### Token-Level Fallback (Tier 2)
+
+For statements without a full AST (Tier 2 or parse failures), the digest works at the token level:
+
+1. Tokenize the input
+2. Walk tokens, emitting each:
+   - Keywords → uppercase
+   - Identifiers → as-is
+   - Literals (TK_INTEGER, TK_FLOAT, TK_STRING) → `?`
+   - Operators/punctuation → as-is
+3. Collapse consecutive `?` in IN/VALUES lists
+4. Hash the result
+
+This ensures digest works for ALL queries, even those the parser doesn't deeply understand.
+
+### Hash Function
+
+64-bit FNV-1a — simple, fast, no external dependency, good distribution. Computed incrementally as the normalized string is built (no second pass).
+ +--- + +## Emitter Extensions + +New `emit_*` methods for each new node type, following the same pattern as existing SET/SELECT emission: + +- `emit_insert_stmt`, `emit_values_clause`, `emit_values_row`, `emit_on_duplicate_key`, `emit_on_conflict`, `emit_returning` +- `emit_update_stmt`, `emit_update_set_clause`, `emit_update_set_item` +- `emit_delete_stmt`, `emit_delete_using` +- `emit_compound_query`, `emit_set_operation` + +The `RETURNING` emitter is shared across INSERT/UPDATE/DELETE. + +**Digest mode** is a constructor flag on the emitter: + +```cpp +enum class EmitMode : uint8_t { NORMAL, DIGEST }; + +Emitter(Arena& arena, EmitMode mode = EmitMode::NORMAL, + const ParamBindings* bindings = nullptr); +``` + +In digest mode, the following methods change behavior: +- `emit_value()` / `emit_string_literal()` — for literal nodes (`NODE_LITERAL_INT`, `NODE_LITERAL_FLOAT`, `NODE_LITERAL_STRING`), emit `?` instead of actual value +- `emit_in_list()` — emit `IN (?)` regardless of how many values, collapsing the list +- `emit_values_row()` — emit `(?, ?, ...)` matching column count but with `?` for all values +- `emit_placeholder()` — emit `?` (same as normal mode, already a placeholder) +- All keyword text emitted in uppercase (e.g., `SELECT`, `FROM`, `WHERE`) +- `emit_alias()` — skip aliases in digest mode (aliases don't affect query semantics for routing) + +Methods that do NOT change in digest mode: structural emission (FROM, JOIN, WHERE, GROUP BY, ORDER BY, LIMIT, etc.) remains identical since the query structure matters for digest grouping. 
+ +--- + +## New Token Additions + +```cpp +// Needed for new syntax: +TK_DELAYED, +TK_HIGH_PRIORITY, +TK_DUPLICATE, +TK_KEY, +TK_CONFLICT, +TK_DO, +TK_NOTHING, +TK_RETURNING, +TK_ONLY, // already exists in enum, verify in keyword tables +TK_EXCEPT, +TK_INTERSECT, +TK_CONSTRAINT, +// Note: DEFAULT VALUES uses existing TK_DEFAULT + TK_VALUES (two-token approach, no compound token needed) +// Note: TK_UNION and TK_OF already exist from Plan 3 +``` + +--- + +## Implementation Plans (separate) + +This spec should be implemented across 5 plans: + +1. **Plan 7: Shared table ref parser + INSERT deep parser** — Extract TableRefParser from SelectParser, then INSERT/REPLACE with all syntax, emitter, tests. Closes #5. +2. **Plan 8: UPDATE deep parser** — full UPDATE syntax, emitter, tests. Closes #6. +3. **Plan 9: DELETE deep parser** — full DELETE syntax, emitter, tests. Closes #7. +4. **Plan 10: Compound queries** — CompoundQueryParser with UNION/INTERSECT/EXCEPT nesting, emitter, tests. Closes #8. +5. **Plan 11: Query digest** — Digest module with both AST and token-level modes, tests. Closes #9. + +**Dependencies:** Plan 7 must come first (extracts shared TableRefParser). Plans 8-9 depend on Plan 7's TableRefParser but are independent of each other. Plan 10 is independent of 7-9. Plan 11 benefits from all prior plans being complete but works with Tier 2 token-level fallback. 
+ +--- + +## Performance Targets + +| Operation | Target | +|---|---| +| INSERT parse (simple VALUES) | <500ns | +| INSERT parse (multi-row + ON DUPLICATE KEY) | <2us | +| UPDATE parse (simple) | <500ns | +| DELETE parse (simple) | <300ns | +| Compound UNION (2 simple SELECTs) | <1us | +| Query digest (simple SELECT) | <500ns | +| Query digest (token-level, Tier 2) | <200ns | diff --git a/include/sql_parser/arena.h b/include/sql_parser/arena.h new file mode 100644 index 0000000..233fe6e --- /dev/null +++ b/include/sql_parser/arena.h @@ -0,0 +1,56 @@ +#ifndef SQL_PARSER_ARENA_H +#define SQL_PARSER_ARENA_H + +#include "sql_parser/common.h" +#include +#include +#include +#include + +namespace sql_parser { + +class Arena { +public: + explicit Arena(size_t block_size = 65536, size_t max_size = 1048576); + ~Arena(); + + Arena(const Arena&) = delete; + Arena& operator=(const Arena&) = delete; + Arena(Arena&&) = delete; + Arena& operator=(Arena&&) = delete; + + void* allocate(size_t bytes); + + template + T* allocate_typed() { + void* mem = allocate(sizeof(T)); + if (!mem) return nullptr; + return new (mem) T{}; + } + + StringRef allocate_string(const char* src, uint32_t len); + + void reset(); + + size_t bytes_used() const; + +private: + struct Block { + Block* next; + size_t capacity; + size_t used; + char* data() { return reinterpret_cast(this) + sizeof(Block); } + }; + + Block* allocate_block(size_t capacity); + + Block* primary_; + Block* current_; + size_t block_size_; + size_t max_size_; + size_t total_allocated_; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_ARENA_H diff --git a/include/sql_parser/ast.h b/include/sql_parser/ast.h new file mode 100644 index 0000000..f51bfa4 --- /dev/null +++ b/include/sql_parser/ast.h @@ -0,0 +1,53 @@ +#ifndef SQL_PARSER_AST_H +#define SQL_PARSER_AST_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include +#include + +namespace sql_parser { + +struct AstNode { + AstNode* first_child; + AstNode* 
next_sibling; + const char* value_ptr; + uint32_t value_len; + NodeType type; + uint16_t flags; + + StringRef value() const { return StringRef{value_ptr, value_len}; } + + void set_value(StringRef ref) { + value_ptr = ref.ptr; + value_len = ref.len; + } + + void add_child(AstNode* child) { + if (!child) return; + if (!first_child) { + first_child = child; + return; + } + AstNode* last = first_child; + while (last->next_sibling) last = last->next_sibling; + last->next_sibling = child; + } +}; +static_assert(sizeof(AstNode) == 32, "AstNode must be 32 bytes"); +static_assert(std::is_trivially_copyable_v); + +inline AstNode* make_node(Arena& arena, NodeType type, StringRef value = {}, + uint16_t flags = 0) { + AstNode* node = arena.allocate_typed(); + if (!node) return nullptr; + node->type = type; + node->flags = flags; + node->value_ptr = value.ptr; + node->value_len = value.len; + return node; +} + +} // namespace sql_parser + +#endif // SQL_PARSER_AST_H diff --git a/include/sql_parser/common.h b/include/sql_parser/common.h new file mode 100644 index 0000000..74242b6 --- /dev/null +++ b/include/sql_parser/common.h @@ -0,0 +1,181 @@ +#ifndef SQL_PARSER_COMMON_H +#define SQL_PARSER_COMMON_H + +#include +#include +#include + +namespace sql_parser { + +// -- Dialect -- + +enum class Dialect : uint8_t { + MySQL, + PostgreSQL +}; + +// -- StringRef: zero-copy view into input buffer -- + +struct StringRef { + const char* ptr = nullptr; + uint32_t len = 0; + + bool empty() const { return len == 0; } + + bool equals_ci(const char* s, uint32_t slen) const { + if (len != slen) return false; + for (uint32_t i = 0; i < len; ++i) { + char a = ptr[i]; + char b = s[i]; + if (a >= 'A' && a <= 'Z') a += 32; + if (b >= 'A' && b <= 'Z') b += 32; + if (a != b) return false; + } + return true; + } + + bool operator==(const StringRef& o) const { + return len == o.len && (ptr == o.ptr || std::memcmp(ptr, o.ptr, len) == 0); + } + bool operator!=(const StringRef& o) const { return !(*this == 
o); } +}; +static_assert(std::is_trivially_copyable_v); + +// Case-insensitive comparison for keyword lookup (used by keyword tables) +inline int ci_cmp(const char* a, uint32_t alen, const char* b, uint8_t blen) { + uint32_t minlen = alen < blen ? alen : blen; + for (uint32_t i = 0; i < minlen; ++i) { + char ca = a[i]; + char cb = b[i]; + if (ca >= 'a' && ca <= 'z') ca -= 32; + if (cb >= 'a' && cb <= 'z') cb -= 32; + if (ca != cb) return ca < cb ? -1 : 1; + } + if (alen < blen) return -1; + if (alen > blen) return 1; + return 0; +} + +// -- Flags for NODE_SET_OPERATION -- +static constexpr uint16_t FLAG_SET_OP_ALL = 0x01; + +// -- Statement type (always set, even for PARTIAL/ERROR) -- + +enum class StmtType : uint8_t { + UNKNOWN = 0, + SELECT, + INSERT, + UPDATE, + DELETE_STMT, + REPLACE, + SET, + USE, + SHOW, + BEGIN, + START_TRANSACTION, + COMMIT, + ROLLBACK, + SAVEPOINT, + PREPARE, + EXECUTE, + DEALLOCATE, + CREATE, + ALTER, + DROP, + TRUNCATE, + GRANT, + REVOKE, + LOCK, + UNLOCK, + LOAD_DATA, + RESET, +}; + +// -- AST node types -- + +enum class NodeType : uint16_t { + NODE_UNKNOWN = 0, + + // Tier 2 lightweight nodes + NODE_STATEMENT, + NODE_TABLE_REF, + NODE_SCHEMA_REF, + NODE_IDENTIFIER, + NODE_QUALIFIED_NAME, + + // Tier 1 nodes (SELECT) + NODE_SELECT_STMT, + NODE_SELECT_OPTIONS, + NODE_SELECT_ITEM_LIST, + NODE_SELECT_ITEM, + NODE_FROM_CLAUSE, + NODE_JOIN_CLAUSE, + NODE_WHERE_CLAUSE, + NODE_GROUP_BY_CLAUSE, + NODE_HAVING_CLAUSE, + NODE_ORDER_BY_CLAUSE, + NODE_ORDER_BY_ITEM, + NODE_LIMIT_CLAUSE, + NODE_LOCKING_CLAUSE, + NODE_INTO_CLAUSE, + NODE_ALIAS, + + // Tier 1 nodes (SET) + NODE_SET_STMT, + NODE_SET_NAMES, + NODE_SET_CHARSET, + NODE_SET_TRANSACTION, + NODE_VAR_ASSIGNMENT, + NODE_VAR_TARGET, + + // Expression nodes + NODE_EXPRESSION, + NODE_BINARY_OP, + NODE_UNARY_OP, + NODE_FUNCTION_CALL, + NODE_LITERAL_INT, + NODE_LITERAL_FLOAT, + NODE_LITERAL_STRING, + NODE_LITERAL_NULL, + NODE_PLACEHOLDER, + NODE_SUBQUERY, + NODE_COLUMN_REF, + NODE_ASTERISK, + 
NODE_IS_NULL, + NODE_IS_NOT_NULL, + NODE_BETWEEN, + NODE_IN_LIST, + NODE_CASE_WHEN, + + // INSERT nodes + NODE_INSERT_STMT, + NODE_INSERT_COLUMNS, // (col1, col2, ...) + NODE_VALUES_CLAUSE, // VALUES keyword wrapper + NODE_VALUES_ROW, // single (val1, val2, ...) row + NODE_INSERT_SET_CLAUSE, // MySQL INSERT ... SET col=val form + NODE_ON_DUPLICATE_KEY, // MySQL ON DUPLICATE KEY UPDATE + NODE_ON_CONFLICT, // PostgreSQL ON CONFLICT + NODE_CONFLICT_TARGET, // PostgreSQL conflict target (cols or ON CONSTRAINT) + NODE_CONFLICT_ACTION, // DO UPDATE SET ... or DO NOTHING + NODE_RETURNING_CLAUSE, // PostgreSQL RETURNING expr_list + + // UPDATE nodes + NODE_UPDATE_STMT, + NODE_UPDATE_SET_CLAUSE, // SET col=expr, col=expr in UPDATE context + + // DELETE nodes + NODE_DELETE_STMT, + NODE_DELETE_USING_CLAUSE, // PostgreSQL USING or MySQL USING form + + // Compound query nodes + NODE_COMPOUND_QUERY, // root for UNION/INTERSECT/EXCEPT + NODE_SET_OPERATION, // operator (UNION, INTERSECT, EXCEPT) with ALL flag + + // Shared + NODE_STMT_OPTIONS, // LOW_PRIORITY, IGNORE, QUICK, DELAYED, etc. 
+ NODE_UPDATE_SET_ITEM, // single col=expr pair (shared by INSERT SET and UPDATE SET) +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_COMMON_H diff --git a/include/sql_parser/compound_query_parser.h b/include/sql_parser/compound_query_parser.h new file mode 100644 index 0000000..015b4ed --- /dev/null +++ b/include/sql_parser/compound_query_parser.h @@ -0,0 +1,304 @@ +#ifndef SQL_PARSER_COMPOUND_QUERY_PARSER_H +#define SQL_PARSER_COMPOUND_QUERY_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/select_parser.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +template +class CompoundQueryParser { +public: + CompoundQueryParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {} + + // Parse a compound query (or a plain SELECT if no set operator follows). + // Returns NODE_SELECT_STMT for plain selects, NODE_COMPOUND_QUERY for compounds. 
+  // Entry point. Returns NODE_COMPOUND_QUERY when at least one set operator
+  // was found, otherwise the bare NODE_SELECT_STMT with its trailing
+  // ORDER BY / LIMIT / FOR clauses attached (deferred because operands are
+  // parsed in compound mode).
+  AstNode* parse() {
+    AstNode* result = parse_compound_expr(0);
+    if (!result) return nullptr;
+
+    // If the result is a set operation, wrap in COMPOUND_QUERY and parse trailing clauses
+    if (result->type == NodeType::NODE_SET_OPERATION) {
+      AstNode* compound = make_node(arena_, NodeType::NODE_COMPOUND_QUERY);
+      if (!compound) return nullptr;
+      compound->add_child(result);
+
+      // Parse trailing ORDER BY (applies to whole compound)
+      if (tok_.peek().type == TokenType::TK_ORDER) {
+        tok_.skip();
+        if (tok_.peek().type == TokenType::TK_BY) tok_.skip();
+        AstNode* order_by = parse_order_by();
+        if (order_by) compound->add_child(order_by);
+      }
+
+      // Parse trailing LIMIT (applies to whole compound)
+      if (tok_.peek().type == TokenType::TK_LIMIT) {
+        tok_.skip();
+        AstNode* limit = parse_limit();
+        if (limit) compound->add_child(limit);
+      }
+
+      return compound;
+    }
+
+    // No set operator found -- return the bare SELECT as-is.
+    // Since we used compound_mode, ORDER BY/LIMIT/FOR weren't consumed.
+    // Parse them now and attach to the SELECT node.
+    if (result->type == NodeType::NODE_SELECT_STMT) {
+      if (tok_.peek().type == TokenType::TK_ORDER) {
+        tok_.skip();
+        if (tok_.peek().type == TokenType::TK_BY) tok_.skip();
+        AstNode* order_by = parse_order_by();
+        if (order_by) result->add_child(order_by);
+      }
+      if (tok_.peek().type == TokenType::TK_LIMIT) {
+        tok_.skip();
+        AstNode* limit = parse_limit();
+        if (limit) result->add_child(limit);
+      }
+      // FOR UPDATE / FOR SHARE
+      if (tok_.peek().type == TokenType::TK_FOR) {
+        tok_.skip();
+        AstNode* lock = make_node(arena_, NodeType::NODE_LOCKING_CLAUSE);
+        if (lock) {
+          Token strength = tok_.next_token();
+          lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, strength.text));
+          result->add_child(lock);
+        }
+      }
+    }
+    return result;
+  }
+
+private:
+  // NOTE(review): the <D> template arguments on Tokenizer/ExpressionParser/
+  // SelectParser below were stripped by the diff mangling; restored to match
+  // the header-only-template pattern used by the other parsers — verify.
+  Tokenizer<D>& tok_;
+  Arena& arena_;
+  ExpressionParser<D> expr_parser_;
+
+  // Precedence levels
+  static constexpr int PREC_UNION_EXCEPT = 1;
+  static constexpr int PREC_INTERSECT = 2;
+
+  // Get the precedence of a set operator token, or 0 if not a set operator
+  static int get_set_op_precedence(TokenType type) {
+    switch (type) {
+      case TokenType::TK_UNION: return PREC_UNION_EXCEPT;
+      case TokenType::TK_EXCEPT: return PREC_UNION_EXCEPT;
+      case TokenType::TK_INTERSECT: return PREC_INTERSECT;
+      default: return 0;
+    }
+  }
+
+  // Check if a token is a set operator
+  static bool is_set_operator(TokenType type) {
+    return type == TokenType::TK_UNION ||
+           type == TokenType::TK_INTERSECT ||
+           type == TokenType::TK_EXCEPT;
+  }
+
+  // Parse a compound expression with minimum precedence (Pratt-style).
+  // Left-associative: the right operand is parsed with the operator's own
+  // precedence as the floor, so equal-precedence operators stop the recursion.
+  AstNode* parse_compound_expr(int min_prec) {
+    AstNode* left = parse_operand();
+    if (!left) return nullptr;
+
+    while (true) {
+      Token t = tok_.peek();
+      int prec = get_set_op_precedence(t.type);
+      if (prec == 0 || prec <= min_prec) break;
+
+      // Consume the set operator
+      tok_.skip();
+      StringRef op_text = t.text;
+
+      // Check for optional ALL
+      uint16_t flags = 0;
+      if (tok_.peek().type == TokenType::TK_ALL) {
+        tok_.skip();
+        flags = FLAG_SET_OP_ALL;
+      }
+
+      // Parse right operand with current precedence as min (left-associative)
+      AstNode* right = parse_compound_expr(prec);
+      if (!right) return nullptr;
+
+      // Build NODE_SET_OPERATION with left and right as children
+      AstNode* setop = make_node(arena_, NodeType::NODE_SET_OPERATION, op_text);
+      if (!setop) return nullptr;
+      setop->flags = flags;
+      setop->add_child(left);
+      setop->add_child(right);
+
+      left = setop;
+    }
+
+    return left;
+  }
+
+  // Parse a single operand: parenthesized compound or plain SELECT
+  AstNode* parse_operand() {
+    if (tok_.peek().type == TokenType::TK_LPAREN) {
+      tok_.skip(); // consume '('
+
+      // Could be a parenthesized compound query or a parenthesized SELECT
+      AstNode* inner = nullptr;
+      if (tok_.peek().type == TokenType::TK_SELECT ||
+          tok_.peek().type == TokenType::TK_LPAREN) {
+        // Parse the inner compound expression recursively
+        // Need to consume SELECT keyword first if present
+        if (tok_.peek().type == TokenType::TK_SELECT) {
+          tok_.skip(); // consume SELECT
+          // Create a SelectParser that will parse from after SELECT
+          SelectParser<D> sp(tok_, arena_, true);
+          AstNode* select = sp.parse();
+
+          // Check if a set operator follows inside the parens
+          if (is_set_operator(tok_.peek().type)) {
+            // There's a compound inside the parens
+            // (continue_compound_from handles a null 'select' itself)
+            inner = continue_compound_from(select, 0);
+          } else if (select) {
+            // FIX: guard on 'select' — previously a failed sp.parse() here
+            // dereferenced nullptr when ORDER BY/LIMIT followed.
+            // Single SELECT inside parens -- parse ORDER BY/LIMIT
+            if (tok_.peek().type == TokenType::TK_ORDER) {
+              tok_.skip();
+              if (tok_.peek().type == TokenType::TK_BY) tok_.skip();
+              AstNode* ob = parse_order_by();
+              if (ob) select->add_child(ob);
+            }
+            if (tok_.peek().type == TokenType::TK_LIMIT) {
+              tok_.skip();
+              AstNode* lim = parse_limit();
+              if (lim) select->add_child(lim);
+            }
+            inner = select;
+          }
+        } else {
+          // Nested parenthesized: ((SELECT ...))
+          inner = parse_compound_expr(0);
+        }
+      }
+
+      // Expect closing ')'
+      if (tok_.peek().type == TokenType::TK_RPAREN) {
+        tok_.skip();
+      }
+
+      return inner;
+    }
+
+    // Not parenthesized -- must be a plain SELECT
+    // Consume SELECT keyword if present (already consumed by classifier
+    // for the first call, but present for subsequent SELECTs in compound)
+    if (tok_.peek().type == TokenType::TK_SELECT) {
+      tok_.skip();
+    }
+    // Use compound_mode=true so ORDER BY/LIMIT aren't consumed
+    SelectParser<D> sp(tok_, arena_, true);
+    return sp.parse();
+  }
+
+  // Continue parsing compound from an already-parsed left operand
+  AstNode* continue_compound_from(AstNode* left, int min_prec) {
+    if (!left) return nullptr;
+
+    while (true) {
+      Token t = tok_.peek();
+      int prec = get_set_op_precedence(t.type);
+      if (prec == 0 || prec <= min_prec) break;
+
+      tok_.skip();
+      StringRef op_text = t.text;
+
+      uint16_t flags = 0;
+      if (tok_.peek().type == TokenType::TK_ALL) {
+        tok_.skip();
+        flags = FLAG_SET_OP_ALL;
+      }
+
+      // Inside parens, operand must start with SELECT or (
+      AstNode* right = nullptr;
+      if (tok_.peek().type == TokenType::TK_SELECT) {
+        tok_.skip();
+        SelectParser<D> sp(tok_, arena_, true);
+        AstNode* rsel = sp.parse();
+        // Check for more operators at higher precedence
+        right = continue_compound_from(rsel, prec);
+      } else if (tok_.peek().type == TokenType::TK_LPAREN) {
+        right = parse_operand(); // handles nested parens
+        right = continue_compound_from(right, prec);
+      }
+
+      if (!right) return nullptr;
+
+      AstNode* setop = make_node(arena_, NodeType::NODE_SET_OPERATION, op_text);
+      if (!setop) return nullptr;
+      setop->flags = flags;
+      setop->add_child(left);
+      setop->add_child(right);
+
+      left = setop;
+    }
+
+    return left;
+  }
+
+  // Parse trailing ORDER BY for compound result
+  AstNode* parse_order_by() {
+    AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE);
+    if (!order_by) return nullptr;
+
+    while (true) {
+      AstNode* expr = expr_parser_.parse();
+      if (!expr) break;
+
+      AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM);
+      if (!item) break; // FIX: arena exhaustion — every other make_node call site checks for nullptr
+      item->add_child(expr);
+
+      // Optional ASC/DESC
+      Token dir = tok_.peek();
+      if
(dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) { + tok_.skip(); + item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text)); + } + + order_by->add_child(item); + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return order_by; + } + + // Parse trailing LIMIT for compound result + AstNode* parse_limit() { + AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE); + if (!limit) return nullptr; + + AstNode* first = expr_parser_.parse(); + if (first) limit->add_child(first); + + if (tok_.peek().type == TokenType::TK_OFFSET) { + tok_.skip(); + AstNode* offset = expr_parser_.parse(); + if (offset) limit->add_child(offset); + } else if (tok_.peek().type == TokenType::TK_COMMA) { + // MySQL: LIMIT offset, count + tok_.skip(); + AstNode* count = expr_parser_.parse(); + if (count) limit->add_child(count); + } + + return limit; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_COMPOUND_QUERY_PARSER_H diff --git a/include/sql_parser/delete_parser.h b/include/sql_parser/delete_parser.h new file mode 100644 index 0000000..867d277 --- /dev/null +++ b/include/sql_parser/delete_parser.h @@ -0,0 +1,352 @@ +#ifndef SQL_PARSER_DELETE_PARSER_H +#define SQL_PARSER_DELETE_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" + +namespace sql_parser { + +// Flags on NODE_DELETE_STMT +static constexpr uint16_t FLAG_DELETE_MULTI_TABLE = 0x01; // multi-table form +static constexpr uint16_t FLAG_DELETE_FORM2 = 0x02; // MySQL form 2 (DELETE FROM ... 
USING) + +template +class DeleteParser { +public: + DeleteParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena), + table_ref_parser_(tokenizer, arena, expr_parser_) {} + + // Parse DELETE statement (DELETE keyword already consumed). + AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_DELETE_STMT); + if (!root) return nullptr; + + if constexpr (D == Dialect::MySQL) { + return parse_mysql(root); + } else { + return parse_pgsql(root); + } + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser table_ref_parser_; + + // ---- MySQL DELETE ---- + // Single-table: DELETE [LOW_PRIORITY] [QUICK] [IGNORE] FROM table [WHERE] [ORDER BY] [LIMIT] + // Multi-table form 1: DELETE [opts] t1, t2 FROM table_refs [WHERE] + // Multi-table form 2: DELETE [opts] FROM t1, t2 USING table_refs [WHERE] + + AstNode* parse_mysql(AstNode* root) { + // Options: LOW_PRIORITY, QUICK, IGNORE + AstNode* opts = parse_stmt_options(); + if (opts) root->add_child(opts); + + if (tok_.peek().type == TokenType::TK_FROM) { + // Could be single-table or multi-table form 2 + tok_.skip(); // consume FROM + + // Parse the first table reference + AstNode* first_table = parse_simple_table_ref(); + if (!first_table) return root; + + // Check if comma follows (target list) or if USING follows + if (tok_.peek().type == TokenType::TK_COMMA || tok_.peek().type == TokenType::TK_USING) { + // Could be multi-table form 2: DELETE FROM t1[, t2] USING ... + // Or single-table with comma would be unusual. Check for USING after table list. 
+ // Collect all target tables + root->add_child(first_table); + + while (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + AstNode* next_table = parse_simple_table_ref(); + if (next_table) root->add_child(next_table); + } + + if (tok_.peek().type == TokenType::TK_USING) { + // Multi-table form 2 + tok_.skip(); // consume USING + root->flags = FLAG_DELETE_MULTI_TABLE | FLAG_DELETE_FORM2; + + // Parse source table references (with JOINs) + AstNode* using_clause = make_node(arena_, NodeType::NODE_DELETE_USING_CLAUSE); + AstNode* from = table_ref_parser_.parse_from_clause(); + if (from) { + // Move children of FROM_CLAUSE into USING_CLAUSE + for (AstNode* c = from->first_child; c; ) { + AstNode* next = c->next_sibling; + c->next_sibling = nullptr; + using_clause->add_child(c); + c = next; + } + } + root->add_child(using_clause); + + // WHERE + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + } else { + // Single-table (just one target table, no USING) + // Parse optional WHERE, ORDER BY, LIMIT + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + if (tok_.peek().type == TokenType::TK_ORDER) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + AstNode* order_by = parse_order_by(); + if (order_by) root->add_child(order_by); + } + if (tok_.peek().type == TokenType::TK_LIMIT) { + tok_.skip(); + AstNode* limit = parse_limit(); + if (limit) root->add_child(limit); + } + } + } else { + // Single-table DELETE: DELETE FROM table [WHERE] [ORDER BY] [LIMIT] + root->add_child(first_table); + + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + if (tok_.peek().type == TokenType::TK_ORDER) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + 
AstNode* order_by = parse_order_by(); + if (order_by) root->add_child(order_by); + } + if (tok_.peek().type == TokenType::TK_LIMIT) { + tok_.skip(); + AstNode* limit = parse_limit(); + if (limit) root->add_child(limit); + } + } + } else { + // Multi-table form 1: DELETE t1[, t2] FROM table_refs [WHERE] + root->flags = FLAG_DELETE_MULTI_TABLE; + + // Parse target table list + AstNode* first_target = parse_simple_table_ref(); + if (first_target) root->add_child(first_target); + + while (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + AstNode* next_target = parse_simple_table_ref(); + if (next_target) root->add_child(next_target); + } + + // Expect FROM + if (tok_.peek().type == TokenType::TK_FROM) { + tok_.skip(); + // Parse source table references (with JOINs) + AstNode* from = table_ref_parser_.parse_from_clause(); + if (from) root->add_child(from); + } + + // WHERE + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + } + + return root; + } + + // ---- PostgreSQL DELETE ---- + // DELETE FROM [ONLY] table [[AS] alias] [USING using_list] [WHERE] [RETURNING] + + AstNode* parse_pgsql(AstNode* root) { + // Consume FROM + if (tok_.peek().type == TokenType::TK_FROM) { + tok_.skip(); + } + + // Optional ONLY keyword + if (tok_.peek().type == TokenType::TK_ONLY) { + AstNode* opts = make_node(arena_, NodeType::NODE_STMT_OPTIONS); + Token only_tok = tok_.next_token(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, only_tok.text)); + root->add_child(opts); + } + + // Single table reference with optional alias + AstNode* table_ref = table_ref_parser_.parse_table_reference(); + if (table_ref) root->add_child(table_ref); + + // USING clause + if (tok_.peek().type == TokenType::TK_USING) { + tok_.skip(); + AstNode* using_clause = make_node(arena_, NodeType::NODE_DELETE_USING_CLAUSE); + + // Parse table list (comma-separated, potentially with JOINs) + while 
(true) { + AstNode* tref = table_ref_parser_.parse_table_reference(); + if (tref) using_clause->add_child(tref); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + root->add_child(using_clause); + } + + // WHERE + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + + // RETURNING + if (tok_.peek().type == TokenType::TK_RETURNING) { + AstNode* ret = parse_returning(); + if (ret) root->add_child(ret); + } + + return root; + } + + // ---- Shared helpers ---- + + // Parse a simple table reference (name or schema.name, no alias parsing for target tables) + AstNode* parse_simple_table_ref() { + AstNode* ref = make_node(arena_, NodeType::NODE_TABLE_REF); + if (!ref) return nullptr; + + Token name = tok_.next_token(); + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token table_name = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, table_name.text)); + ref->add_child(qname); + } else { + ref->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + } + + return ref; + } + + // Parse MySQL options: LOW_PRIORITY, QUICK, IGNORE + AstNode* parse_stmt_options() { + AstNode* opts = nullptr; + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_LOW_PRIORITY || + t.type == TokenType::TK_QUICK || + t.type == TokenType::TK_IGNORE) { + if (!opts) opts = make_node(arena_, NodeType::NODE_STMT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else { + break; + } + } + return opts; + } + + // Parse WHERE clause + AstNode* parse_where_clause() { + AstNode* where = make_node(arena_, NodeType::NODE_WHERE_CLAUSE); + if (!where) return nullptr; + AstNode* expr = expr_parser_.parse(); 
+ if (expr) where->add_child(expr); + return where; + } + + // Parse ORDER BY clause + AstNode* parse_order_by() { + AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE); + if (!order_by) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + + AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM); + item->add_child(expr); + + // Optional ASC/DESC + Token dir = tok_.peek(); + if (dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) { + tok_.skip(); + item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text)); + } + + order_by->add_child(item); + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return order_by; + } + + // Parse LIMIT clause + AstNode* parse_limit() { + AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE); + if (!limit) return nullptr; + + AstNode* count = expr_parser_.parse(); + if (count) limit->add_child(count); + + return limit; + } + + // Parse PostgreSQL RETURNING clause + AstNode* parse_returning() { + if (tok_.peek().type != TokenType::TK_RETURNING) return nullptr; + tok_.skip(); // RETURNING + + AstNode* ret = make_node(arena_, NodeType::NODE_RETURNING_CLAUSE); + if (!ret) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + ret->add_child(expr); + + // Check for optional alias + Token next = tok_.peek(); + if (next.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + ret->add_child(make_node(arena_, NodeType::NODE_ALIAS, alias_name.text)); + } + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + + return ret; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_DELETE_PARSER_H diff --git a/include/sql_parser/digest.h b/include/sql_parser/digest.h new file mode 100644 index 0000000..2acd52a --- /dev/null +++ b/include/sql_parser/digest.h @@ -0,0 +1,287 @@ +#ifndef 
SQL_PARSER_DIGEST_H +#define SQL_PARSER_DIGEST_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/emitter.h" +#include "sql_parser/string_builder.h" +#include + +namespace sql_parser { + +struct DigestResult { + StringRef normalized; // "SELECT * FROM t WHERE id = ?" + uint64_t hash; // 64-bit FNV-1a hash +}; + +// FNV-1a 64-bit hash -- simple, fast, no external dependency +struct FnvHash { + static constexpr uint64_t FNV_OFFSET_BASIS = 14695981039346656037ULL; + static constexpr uint64_t FNV_PRIME = 1099511628211ULL; + + uint64_t state = FNV_OFFSET_BASIS; + + void update(const char* data, size_t len) { + for (size_t i = 0; i < len; ++i) { + state ^= static_cast(static_cast(data[i])); + state *= FNV_PRIME; + } + } + + void update_char(char c) { + state ^= static_cast(static_cast(c)); + state *= FNV_PRIME; + } + + uint64_t finish() const { return state; } +}; + +template +class Digest { +public: + explicit Digest(Arena& arena) : arena_(arena) {} + + // From a parsed AST (Tier 1) -- uses Emitter in DIGEST mode + DigestResult compute(const AstNode* ast) { + Emitter emitter(arena_, EmitMode::DIGEST); + emitter.emit(ast); + StringRef normalized = emitter.result(); + FnvHash hasher; + hasher.update(normalized.ptr, normalized.len); + return DigestResult{normalized, hasher.finish()}; + } + + // From raw SQL (works for any statement) -- uses token-level fallback + DigestResult compute(const char* sql, size_t len) { + return compute_token_level(sql, len); + } + +private: + Arena& arena_; + + // Helper: check if a token type is a keyword (not an identifier, literal, or operator) + static bool is_keyword_token(TokenType type) { + // Keywords start at TK_SELECT and go through TK_EXCEPT + return static_cast(type) >= static_cast(TokenType::TK_SELECT); + } + + // Helper: check if a token type is a literal value that should become ? 
+ static bool is_literal_token(TokenType type) { + return type == TokenType::TK_INTEGER || + type == TokenType::TK_FLOAT || + type == TokenType::TK_STRING; + } + + // Helper: uppercase a character + static char to_upper(char c) { + return (c >= 'a' && c <= 'z') ? (c - 32) : c; + } + + // Append token text uppercased to StringBuilder + static void append_upper(StringBuilder& sb, const char* ptr, uint32_t len) { + for (uint32_t i = 0; i < len; ++i) { + sb.append_char(to_upper(ptr[i])); + } + } + + // Determine if we need a space before this token given the previous token type + static bool needs_space_before(TokenType prev, TokenType cur) { + // Never space after ( or before ) + if (prev == TokenType::TK_LPAREN) return false; + if (cur == TokenType::TK_RPAREN) return false; + // No space before or after dot + if (prev == TokenType::TK_DOT || cur == TokenType::TK_DOT) return false; + // No space before comma + if (cur == TokenType::TK_COMMA) return false; + // No space after @ or @@ + if (prev == TokenType::TK_AT || prev == TokenType::TK_DOUBLE_AT) return false; + // No space before @ + if (cur == TokenType::TK_AT) return false; + return true; + } + + // Emit a single token to the string builder, uppercasing keywords, replacing literals with ? + void emit_token(StringBuilder& sb, const Token& t, TokenType prev) { + bool space = (prev != TokenType::TK_EOF) && needs_space_before(prev, t.type); + if (space) sb.append_char(' '); + + if (is_literal_token(t.type)) { + sb.append_char('?'); + } else if (is_keyword_token(t.type)) { + append_upper(sb, t.text.ptr, t.text.len); + } else if (t.type == TokenType::TK_IDENTIFIER) { + sb.append(t.text.ptr, t.text.len); + } else if (t.type == TokenType::TK_QUESTION) { + sb.append_char('?'); + } else if (t.type == TokenType::TK_COMMA) { + sb.append(",", 1); + } else { + // All other tokens: emit as-is + sb.append(t.text.ptr, t.text.len); + } + } + + // Skip tokens inside parentheses until matching close paren. 
Returns last token type consumed. + void skip_paren_contents(Tokenizer& tok) { + int depth = 1; + while (depth > 0) { + Token inner = tok.next_token(); + if (inner.type == TokenType::TK_EOF) break; + if (inner.type == TokenType::TK_LPAREN) depth++; + if (inner.type == TokenType::TK_RPAREN) depth--; + } + } + + // Token-level digest: walk tokens, normalize, hash + DigestResult compute_token_level(const char* sql, size_t len) { + Tokenizer tok; + tok.reset(sql, len); + StringBuilder sb(arena_); + TokenType prev = TokenType::TK_EOF; + + // We collect tokens into a small buffer for lookahead patterns + // Main loop: read token, check for special patterns, emit + + Token t = tok.next_token(); + + while (t.type != TokenType::TK_EOF && t.type != TokenType::TK_SEMICOLON) { + + // Pattern: IN (...) -> collapse to IN (?) + if (t.type == TokenType::TK_IN) { + emit_token(sb, t, prev); + prev = t.type; + + Token next = tok.next_token(); + if (next.type == TokenType::TK_LPAREN) { + // Emit " (" + emit_token(sb, next, prev); + prev = next.type; + // Collapse contents to single ? + bool emitted_q = false; + int depth = 1; + while (depth > 0) { + Token inner = tok.next_token(); + if (inner.type == TokenType::TK_EOF) break; + if (inner.type == TokenType::TK_LPAREN) { depth++; continue; } + if (inner.type == TokenType::TK_RPAREN) { + depth--; + if (depth == 0) { + sb.append_char(')'); + prev = TokenType::TK_RPAREN; + break; + } + continue; + } + if (!emitted_q) { + sb.append_char('?'); + prev = TokenType::TK_QUESTION; + emitted_q = true; + } + } + t = tok.next_token(); + continue; + } else { + // IN not followed by ( -- process next token normally + t = next; + continue; + } + } + + // Pattern: VALUES (...), (...), ... -> collapse to VALUES (?, ?, ...) 
+ if (t.type == TokenType::TK_VALUES) { + emit_token(sb, t, prev); + prev = t.type; + + Token next = tok.next_token(); + if (next.type == TokenType::TK_LPAREN) { + // Emit the opening paren + emit_token(sb, next, prev); + prev = next.type; + + // Emit first row contents with ? for each value slot + int depth = 1; + while (depth > 0) { + Token inner = tok.next_token(); + if (inner.type == TokenType::TK_EOF) break; + if (inner.type == TokenType::TK_LPAREN) { + depth++; + continue; + } + if (inner.type == TokenType::TK_RPAREN) { + depth--; + if (depth == 0) { + sb.append_char(')'); + prev = TokenType::TK_RPAREN; + break; + } + continue; + } + if (inner.type == TokenType::TK_COMMA && depth == 1) { + sb.append(", ", 2); + prev = TokenType::TK_COMMA; + continue; + } + // Emit ? for literals and existing placeholders + if (is_literal_token(inner.type) || inner.type == TokenType::TK_QUESTION) { + // Only emit ? once per value slot (skip if prev already emitted one) + if (prev == TokenType::TK_LPAREN || prev == TokenType::TK_COMMA) { + sb.append_char('?'); + prev = TokenType::TK_QUESTION; + } + } + } + + // Skip additional rows: , (...) 
+ while (true) { + Token peek = tok.next_token(); + if (peek.type == TokenType::TK_COMMA) { + Token peek2 = tok.next_token(); + if (peek2.type == TokenType::TK_LPAREN) { + // Skip this entire row + skip_paren_contents(tok); + continue; + } else { + // Comma but not followed by ( -- it's not another row + // Emit the comma and continue with peek2 + sb.append(",", 1); + prev = TokenType::TK_COMMA; + t = peek2; + goto emit_normal; + } + } else { + // Not a comma - done with VALUES rows + t = peek; + goto emit_normal; + } + } + } else { + // VALUES not followed by ( -- process next normally + t = next; + continue; + } + } + + emit_normal: + // Check if we've reached the end (can happen after VALUES/IN lookahead) + if (t.type == TokenType::TK_EOF || t.type == TokenType::TK_SEMICOLON) break; + + emit_token(sb, t, prev); + prev = t.type; + // For literal tokens, record as TK_QUESTION since we emitted ? + if (is_literal_token(t.type)) prev = TokenType::TK_QUESTION; + + t = tok.next_token(); + } + + StringRef normalized = sb.finish(); + FnvHash hasher; + hasher.update(normalized.ptr, normalized.len); + return DigestResult{normalized, hasher.finish()}; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_DIGEST_H diff --git a/include/sql_parser/emitter.h b/include/sql_parser/emitter.h new file mode 100644 index 0000000..f1db702 --- /dev/null +++ b/include/sql_parser/emitter.h @@ -0,0 +1,934 @@ +#ifndef SQL_PARSER_EMITTER_H +#define SQL_PARSER_EMITTER_H + +#include "sql_parser/common.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/string_builder.h" +#include "sql_parser/parse_result.h" +#include + +namespace sql_parser { + +enum class EmitMode : uint8_t { NORMAL, DIGEST }; + +template +class Emitter { +public: + explicit Emitter(Arena& arena, EmitMode mode = EmitMode::NORMAL, + const ParamBindings* bindings = nullptr) + : sb_(arena), bindings_(bindings), placeholder_index_(0), mode_(mode) {} + + void emit(const AstNode* node) { + 
        if (!node) return;
        emit_node(node);
    }

    // Finalize and return the emitted text.
    StringRef result() { return sb_.finish(); }

private:
    StringBuilder sb_;
    const ParamBindings* bindings_;   // optional values substituted for placeholders
    uint16_t placeholder_index_;      // next binding slot consumed by emit_placeholder
    EmitMode mode_;

    // Dispatch on node type. Unknown node types fall through to emit_value.
    void emit_node(const AstNode* node) {
        switch (node->type) {
            // ---- SET statement ----
            case NodeType::NODE_SET_STMT:        emit_set_stmt(node); break;
            case NodeType::NODE_SET_NAMES:       emit_set_names(node); break;
            case NodeType::NODE_SET_CHARSET:     emit_set_charset(node); break;
            case NodeType::NODE_SET_TRANSACTION: emit_set_transaction(node); break;
            case NodeType::NODE_VAR_ASSIGNMENT:  emit_var_assignment(node); break;
            case NodeType::NODE_VAR_TARGET:      emit_var_target(node); break;

            // ---- SELECT statement ----
            case NodeType::NODE_SELECT_STMT:     emit_select_stmt(node); break;
            case NodeType::NODE_SELECT_OPTIONS:  emit_select_options(node); break;
            case NodeType::NODE_SELECT_ITEM_LIST:emit_select_item_list(node); break;
            case NodeType::NODE_SELECT_ITEM:     emit_select_item(node); break;
            case NodeType::NODE_FROM_CLAUSE:     emit_from_clause(node); break;
            case NodeType::NODE_JOIN_CLAUSE:     emit_join_clause(node); break;
            case NodeType::NODE_WHERE_CLAUSE:    emit_where_clause(node); break;
            case NodeType::NODE_GROUP_BY_CLAUSE: emit_group_by(node); break;
            case NodeType::NODE_HAVING_CLAUSE:   emit_having(node); break;
            case NodeType::NODE_ORDER_BY_CLAUSE: emit_order_by(node); break;
            case NodeType::NODE_ORDER_BY_ITEM:   emit_order_by_item(node); break;
            case NodeType::NODE_LIMIT_CLAUSE:    emit_limit(node); break;
            case NodeType::NODE_LOCKING_CLAUSE:  emit_locking(node); break;
            case NodeType::NODE_INTO_CLAUSE:     emit_into(node); break;

            // ---- INSERT statement ----
            case NodeType::NODE_INSERT_STMT:     emit_insert_stmt(node); break;
            case NodeType::NODE_INSERT_COLUMNS:  emit_insert_columns(node); break;
            case NodeType::NODE_VALUES_CLAUSE:   emit_values_clause(node); break;
            case NodeType::NODE_VALUES_ROW:      emit_values_row(node); break;
            case NodeType::NODE_INSERT_SET_CLAUSE: emit_insert_set_clause(node); break;
            case NodeType::NODE_ON_DUPLICATE_KEY: emit_on_duplicate_key(node); break;
            case NodeType::NODE_ON_CONFLICT:     emit_on_conflict(node); break;
            case NodeType::NODE_CONFLICT_TARGET: emit_conflict_target(node); break;
            case NodeType::NODE_CONFLICT_ACTION: emit_conflict_action(node); break;
            case NodeType::NODE_RETURNING_CLAUSE: emit_returning(node); break;
            case NodeType::NODE_STMT_OPTIONS:    emit_stmt_options(node); break;
            case NodeType::NODE_UPDATE_SET_ITEM: emit_update_set_item(node); break;

            // ---- Compound query ----
            case NodeType::NODE_COMPOUND_QUERY:  emit_compound_query(node); break;
            case NodeType::NODE_SET_OPERATION:   emit_set_operation(node); break;

            // ---- DELETE statement ----
            case NodeType::NODE_DELETE_STMT:     emit_delete_stmt(node); break;
            case NodeType::NODE_DELETE_USING_CLAUSE: emit_delete_using(node); break;

            // ---- UPDATE statement ----
            case NodeType::NODE_UPDATE_STMT:     emit_update_stmt(node); break;
            case NodeType::NODE_UPDATE_SET_CLAUSE: emit_update_set_clause(node); break;

            // ---- Table references ----
            case NodeType::NODE_TABLE_REF:       emit_table_ref(node); break;
            case NodeType::NODE_ALIAS:           emit_alias(node); break;
            case NodeType::NODE_QUALIFIED_NAME:  emit_qualified_name(node); break;

            // ---- Expressions ----
            case NodeType::NODE_BINARY_OP:       emit_binary_op(node); break;
            case NodeType::NODE_UNARY_OP:        emit_unary_op(node); break;
            case NodeType::NODE_FUNCTION_CALL:   emit_function_call(node); break;
            case NodeType::NODE_IS_NULL:         emit_is_null(node); break;
            case NodeType::NODE_IS_NOT_NULL:     emit_is_not_null(node); break;
            case NodeType::NODE_BETWEEN:         emit_between(node); break;
            case NodeType::NODE_IN_LIST:         emit_in_list(node); break;
            case NodeType::NODE_CASE_WHEN:       emit_case_when(node); break;
            // NOTE(review): subqueries are emitted via their stored value
            // text rather than re-walked -- confirm this is intentional.
            case NodeType::NODE_SUBQUERY:        emit_value(node); break;

            // ---- Placeholder (may substitute a bound value) ----
            case NodeType::NODE_PLACEHOLDER:
                emit_placeholder(node); break;

            // ---- Leaf nodes (emit value directly; literals become '?'
            //      in DIGEST mode) ----
            case NodeType::NODE_LITERAL_INT:
            case NodeType::NODE_LITERAL_FLOAT:
                if (mode_ == EmitMode::DIGEST) { sb_.append_char('?'); break; }
                emit_value(node); break;
            case NodeType::NODE_LITERAL_NULL:
            case NodeType::NODE_COLUMN_REF:
            case NodeType::NODE_ASTERISK:
            case NodeType::NODE_IDENTIFIER:
                emit_value(node); break;

            case NodeType::NODE_LITERAL_STRING:
                if (mode_ == EmitMode::DIGEST) { sb_.append_char('?'); break; }
                emit_string_literal(node); break;

            default:
                emit_value(node); break;
        }
    }

    // Emit the node's raw value text.
    void emit_value(const AstNode* node) {
        sb_.append(node->value_ptr, node->value_len);
    }

    // Emit a single-quoted string literal.
    // NOTE(review): the stored value is emitted verbatim -- quotes inside it
    // are not re-escaped; presumably the parser kept the source escaping.
    void emit_string_literal(const AstNode* node) {
        sb_.append_char('\'');
        sb_.append(node->value_ptr, node->value_len);
        sb_.append_char('\'');
    }

    // Emit a placeholder, substituting the next bound value when bindings
    // are available; otherwise emit the placeholder text as-is.
    void emit_placeholder(const AstNode* node) {
        if (bindings_ && placeholder_index_ < bindings_->count) {
            const BoundValue& bv = bindings_->values[placeholder_index_];
            ++placeholder_index_;
            switch (bv.type) {
                case BoundValue::INT:
                    { char buf[32]; int n = snprintf(buf, sizeof(buf), "%lld", (long long)bv.int_val);
                      sb_.append(buf, n); }
                    break;
                case BoundValue::FLOAT:
                    { char buf[64]; int n = snprintf(buf, sizeof(buf), "%g", (double)bv.float32_val);
                      sb_.append(buf, n); }
                    break;
                case BoundValue::DOUBLE:
                    { char buf[64]; int n = snprintf(buf, sizeof(buf), "%g", bv.float64_val);
                      sb_.append(buf, n); }
                    break;
                case BoundValue::STRING:
                case BoundValue::DATETIME:
                case BoundValue::DECIMAL:
                    sb_.append_char('\'');
                    sb_.append(bv.str_val);
                    sb_.append_char('\'');
                    break;
                case BoundValue::BLOB:
                    // NOTE(review): BLOBs are emitted unquoted/unescaped --
                    // confirm the caller pre-encodes them (e.g. hex literal).
                    sb_.append(bv.str_val);
                    break;
                case BoundValue::NULL_VAL:
                    sb_.append("NULL", 4);
                    break;
            }
        } else {
            // No binding available -- emit the placeholder as-is.
            emit_value(node);
        }
    }

    // ---- SET ----

    void emit_set_stmt(const AstNode* node) {
        sb_.append("SET ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
    }

    // Children: charset [, collation].
    void emit_set_names(const AstNode* node) {
        sb_.append("NAMES ");
        const AstNode* charset = node->first_child;
        if (charset) emit_node(charset);
        const AstNode* collation = charset ? charset->next_sibling : nullptr;
        if (collation) {
            sb_.append(" COLLATE ");
            emit_node(collation);
        }
    }

    void emit_set_charset(const AstNode* node) {
        // Always emit as CHARACTER SET (canonical form).
        sb_.append("CHARACTER SET ");
        if (node->first_child) emit_node(node->first_child);
    }

    void emit_set_transaction(const AstNode* node) {
        sb_.append("TRANSACTION ");
        const AstNode* child = node->first_child;
        // First child may be scope (GLOBAL/SESSION/LOCAL), which the SET
        // statement emitter has already placed before TRANSACTION -- skip it.
        if (child && child->value_len > 0) {
            StringRef val = child->value();
            if (val.equals_ci("GLOBAL", 6) || val.equals_ci("SESSION", 7) ||
                val.equals_ci("LOCAL", 5)) {
                child = child->next_sibling;
            }
        }
        if (child) {
            StringRef val = child->value();
            // Access modes are emitted verbatim; anything else is assumed to
            // be an isolation level value.
            if (val.equals_ci("READ ONLY", 9) || val.equals_ci("READ WRITE", 10)) {
                emit_node(child);
            } else {
                sb_.append("ISOLATION LEVEL ");
                emit_node(child);
            }
        }
    }

    // Children: target, rhs -- emitted as "target = rhs".
    void emit_var_assignment(const AstNode* node) {
        const AstNode* target = node->first_child;
        const AstNode* rhs = target ? target->next_sibling : nullptr;

        if (target) emit_node(target);
        sb_.append(" = ");
        if (rhs) emit_node(rhs);
    }

    void emit_var_target(const AstNode* node) {
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append_char(' ');
            first = false;
            emit_node(child);
        }
    }

    // ---- SELECT ----

    // Clause child nodes carry their own leading " KEYWORD " text, so the
    // children are emitted back-to-back without separators.
    void emit_select_stmt(const AstNode* node) {
        sb_.append("SELECT ");
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            emit_node(child);
        }
    }

    void emit_select_options(const AstNode* node) {
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            emit_node(child);
            sb_.append_char(' ');
        }
    }

    void emit_select_item_list(const AstNode* node) {
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
    }

    // Children: expr [, alias].
    void emit_select_item(const AstNode* node) {
        const AstNode* expr = node->first_child;
        if (expr) emit_node(expr);
        const AstNode* alias = expr ? expr->next_sibling : nullptr;
        if (alias && alias->type == NodeType::NODE_ALIAS) {
            emit_node(alias);
        }
    }

    // JOIN children are space-separated; plain table refs comma-separated.
    void emit_from_clause(const AstNode* node) {
        sb_.append(" FROM ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (child->type == NodeType::NODE_JOIN_CLAUSE) {
                sb_.append_char(' ');
                emit_node(child);
            } else {
                if (!first) sb_.append(", ");
                first = false;
                emit_node(child);
            }
        }
    }

    void emit_table_ref(const AstNode* node) {
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            emit_node(child);
        }
    }

    void emit_alias(const AstNode* node) {
        if (mode_ == EmitMode::DIGEST) return; // aliases are dropped in digests
        sb_.append(" AS ");
        emit_value(node);
    }

    void emit_qualified_name(const AstNode* node) {
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append_char('.');
            first = false;
            emit_node(child);
        }
    }

    // Node value holds the join type text; children: table_ref,
    // then optionally an ON expression or a USING column-list marker.
    void emit_join_clause(const AstNode* node) {
        emit_value(node);
        sb_.append_char(' ');
        const AstNode* table = node->first_child;
        if (table) {
            emit_node(table);
        }
        const AstNode* condition = table ? table->next_sibling : nullptr;
        if (condition) {
            // NOTE(review): the USING marker is matched case-sensitively --
            // confirm the parser always stores it uppercase.
            if (condition->type == NodeType::NODE_IDENTIFIER &&
                condition->value_len == 5 &&
                std::memcmp(condition->value_ptr, "USING", 5) == 0) {
                sb_.append(" USING (");
                bool first = true;
                for (const AstNode* col = condition->first_child; col; col = col->next_sibling) {
                    if (!first) sb_.append(", ");
                    first = false;
                    emit_node(col);
                }
                sb_.append_char(')');
            } else {
                sb_.append(" ON ");
                emit_node(condition);
            }
        }
    }

    void emit_where_clause(const AstNode* node) {
        sb_.append(" WHERE ");
        if (node->first_child) emit_node(node->first_child);
    }

    void emit_group_by(const AstNode* node) {
        sb_.append(" GROUP BY ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
    }

    void emit_having(const AstNode* node) {
        sb_.append(" HAVING ");
        if (node->first_child) emit_node(node->first_child);
    }

    void emit_order_by(const AstNode* node) {
        sb_.append(" ORDER BY ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
    }

    // Children: expr [, direction identifier (ASC/DESC)].
    void emit_order_by_item(const AstNode* node) {
        const AstNode* expr = node->first_child;
        if (expr) emit_node(expr);
        const AstNode* dir = expr ? expr->next_sibling : nullptr;
        if (dir) {
            sb_.append_char(' ');
            emit_node(dir);
        }
    }

    // Children: count [, offset] -- emitted as "LIMIT count [OFFSET offset]".
    void emit_limit(const AstNode* node) {
        sb_.append(" LIMIT ");
        const AstNode* first_val = node->first_child;
        if (first_val) emit_node(first_val);
        const AstNode* second_val = first_val ? first_val->next_sibling : nullptr;
        if (second_val) {
            sb_.append(" OFFSET ");
            emit_node(second_val);
        }
    }

    void emit_locking(const AstNode* node) {
        sb_.append(" FOR ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append_char(' ');
            first = false;
            emit_node(child);
        }
    }

    void emit_into(const AstNode* node) {
        sb_.append(" INTO ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append_char(' ');
            first = false;
            emit_node(child);
        }
    }

    // ---- INSERT ----

    void emit_insert_stmt(const AstNode* node) {
        // Flag bit 0x01 marks REPLACE (MySQL) instead of INSERT.
        if (node->flags & 0x01) {
            sb_.append("REPLACE");
        } else {
            sb_.append("INSERT");
        }

        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            switch (child->type) {
                case NodeType::NODE_STMT_OPTIONS:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_TABLE_REF:
                    sb_.append(" INTO ");
                    emit_node(child);
                    break;
                case NodeType::NODE_INSERT_COLUMNS:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_VALUES_CLAUSE:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_SELECT_STMT:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_INSERT_SET_CLAUSE:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_ON_DUPLICATE_KEY:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_ON_CONFLICT:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                case NodeType::NODE_RETURNING_CLAUSE:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
                default:
                    sb_.append_char(' ');
                    emit_node(child);
                    break;
            }
        }
    }

    void emit_stmt_options(const AstNode* node) {
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append_char(' ');
            first = false;
            emit_node(child);
        }
    }

    void emit_insert_columns(const AstNode* node) {
        sb_.append_char('(');
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
        sb_.append_char(')');
    }

    void emit_values_clause(const AstNode* node) {
        // A non-empty node value means "DEFAULT VALUES" (PostgreSQL form).
        if (node->value_len > 0) {
            emit_value(node); // "DEFAULT VALUES"
            return;
        }
        sb_.append("VALUES ");
        if (mode_ == EmitMode::DIGEST) {
            // Digest mode: collapse multi-row VALUES to the first row only.
            if (node->first_child) emit_node(node->first_child);
        } else {
            bool first = true;
            for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
                if (!first) sb_.append(", ");
                first = false;
                emit_node(child);
            }
        }
    }

    void emit_values_row(const AstNode* node) {
        sb_.append_char('(');
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
        sb_.append_char(')');
    }

    void emit_insert_set_clause(const AstNode* node) {
        sb_.append("SET ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
    }

    // Children: column, value -- emitted as "col = value".
    void emit_update_set_item(const AstNode* node) {
        const AstNode* col = node->first_child;
        const AstNode* val = col ? col->next_sibling : nullptr;
        if (col) emit_node(col);
        sb_.append(" = ");
        if (val) emit_node(val);
    }

    void emit_on_duplicate_key(const AstNode* node) {
        sb_.append("ON DUPLICATE KEY UPDATE ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            if (!first) sb_.append(", ");
            first = false;
            emit_node(child);
        }
    }

    void emit_on_conflict(const AstNode* node) {
        sb_.append("ON CONFLICT");
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            sb_.append_char(' ');
            emit_node(child);
        }
    }

    void emit_conflict_target(const AstNode* node) {
        if (node->value_len > 0) {
            // ON CONSTRAINT name -- the node value holds "ON CONSTRAINT".
            emit_value(node);
            sb_.append_char(' ');
            if (node->first_child) emit_node(node->first_child);
        } else {
            // Column list form: (col1, col2, ...)
            sb_.append_char('(');
            bool first = true;
            for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
                if (!first) sb_.append(", ");
                first = false;
                emit_node(child);
            }
            sb_.append_char(')');
        }
    }

    // Node value selects the action: NOTHING or UPDATE (case-insensitive).
    void emit_conflict_action(const AstNode* node) {
        sb_.append("DO ");
        StringRef action_type{node->value_ptr, node->value_len};
        if (action_type.equals_ci("NOTHING", 7)) {
            sb_.append("NOTHING");
        } else if (action_type.equals_ci("UPDATE", 6)) {
            sb_.append("UPDATE SET ");
            bool first = true;
            for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
                // A trailing WHERE child carries its own leading " WHERE ".
                if (child->type == NodeType::NODE_WHERE_CLAUSE) {
                    emit_node(child);
                } else {
                    if (!first) sb_.append(", ");
                    first = false;
                    emit_node(child);
                }
            }
        }
    }

    void emit_returning(const AstNode* node) {
        sb_.append("RETURNING ");
        bool first = true;
        for (const AstNode* child = node->first_child; child; child = child->next_sibling) {
            // NODE_ALIAS children attach to the preceding expression,
            // so they take no comma separator of their own.
            if (child->type == NodeType::NODE_ALIAS) {
                emit_node(child);
            } else {
                if (!first) sb_.append(", ");
                first = false;
                emit_node(child);
            }
        }
    }
+ + // ---- UPDATE ---- + + void emit_update_stmt(const AstNode* node) { + sb_.append("UPDATE"); + + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + switch (child->type) { + case NodeType::NODE_STMT_OPTIONS: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_TABLE_REF: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_FROM_CLAUSE: + { + // Check if FROM_CLAUSE is before SET_CLAUSE (MySQL multi-table) + // or after (PostgreSQL FROM) + bool is_before_set = false; + for (const AstNode* s = child->next_sibling; s; s = s->next_sibling) { + if (s->type == NodeType::NODE_UPDATE_SET_CLAUSE) { + is_before_set = true; + break; + } + } + if (is_before_set) { + // MySQL multi-table: emit table refs without FROM keyword + sb_.append_char(' '); + bool first = true; + for (const AstNode* c = child->first_child; c; c = c->next_sibling) { + if (c->type == NodeType::NODE_JOIN_CLAUSE) { + sb_.append_char(' '); + emit_node(c); + } else { + if (!first) sb_.append(", "); + first = false; + emit_node(c); + } + } + } else { + // PostgreSQL FROM clause (emit_from_clause adds " FROM " prefix) + emit_node(child); + } + } + break; + case NodeType::NODE_UPDATE_SET_CLAUSE: + sb_.append(" SET "); + emit_update_set_clause_inner(child); + break; + case NodeType::NODE_WHERE_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_ORDER_BY_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_LIMIT_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_RETURNING_CLAUSE: + sb_.append_char(' '); + emit_node(child); + break; + default: + sb_.append_char(' '); + emit_node(child); + break; + } + } + } + + void emit_update_set_clause(const AstNode* node) { + sb_.append("SET "); + emit_update_set_clause_inner(node); + } + + void emit_update_set_clause_inner(const AstNode* node) { + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (!first) 
sb_.append(", "); + first = false; + emit_node(child); + } + } + + // ---- DELETE ---- + + void emit_delete_stmt(const AstNode* node) { + sb_.append("DELETE"); + + // Flags determine the form + bool is_multi = (node->flags & 0x01) != 0; + bool is_form2 = (node->flags & 0x02) != 0; + + if (is_multi && !is_form2) { + // Form 1: DELETE [opts] t1, t2 FROM table_refs [WHERE] + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + switch (child->type) { + case NodeType::NODE_STMT_OPTIONS: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_TABLE_REF: { + // Check if previous sibling is also TABLE_REF (needs comma) + bool prev_is_table = false; + for (const AstNode* p = node->first_child; p != child; p = p->next_sibling) { + if (p->type == NodeType::NODE_TABLE_REF) prev_is_table = true; + else prev_is_table = false; + } + if (prev_is_table) sb_.append(", "); + else sb_.append_char(' '); + emit_node(child); + break; + } + case NodeType::NODE_FROM_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_WHERE_CLAUSE: + emit_node(child); + break; + default: + sb_.append_char(' '); + emit_node(child); + break; + } + } + } else if (is_multi && is_form2) { + // Form 2: DELETE [opts] FROM t1, t2 USING table_refs [WHERE] + bool first_table = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + switch (child->type) { + case NodeType::NODE_STMT_OPTIONS: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_TABLE_REF: + if (first_table) { + sb_.append(" FROM "); + first_table = false; + } else { + sb_.append(", "); + } + emit_node(child); + break; + case NodeType::NODE_DELETE_USING_CLAUSE: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_WHERE_CLAUSE: + emit_node(child); + break; + default: + sb_.append_char(' '); + emit_node(child); + break; + } + } + } else { + // Single-table or PostgreSQL + for (const AstNode* child = 
node->first_child; child; child = child->next_sibling) { + switch (child->type) { + case NodeType::NODE_STMT_OPTIONS: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_TABLE_REF: + sb_.append(" FROM "); + emit_node(child); + break; + case NodeType::NODE_DELETE_USING_CLAUSE: + sb_.append_char(' '); + emit_node(child); + break; + case NodeType::NODE_WHERE_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_ORDER_BY_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_LIMIT_CLAUSE: + emit_node(child); + break; + case NodeType::NODE_RETURNING_CLAUSE: + sb_.append_char(' '); + emit_node(child); + break; + default: + sb_.append_char(' '); + emit_node(child); + break; + } + } + } + } + + void emit_delete_using(const AstNode* node) { + sb_.append("USING "); + bool first = true; + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (child->type == NodeType::NODE_JOIN_CLAUSE) { + sb_.append_char(' '); + emit_node(child); + } else { + if (!first) sb_.append(", "); + first = false; + emit_node(child); + } + } + } + + // ---- Compound query ---- + + void emit_compound_query(const AstNode* node) { + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + if (child->type == NodeType::NODE_SET_OPERATION) { + emit_set_operation(child); + } else { + // Trailing ORDER BY or LIMIT + emit_node(child); + } + } + } + + void emit_set_operation(const AstNode* node) { + const AstNode* left = node->first_child; + const AstNode* right = left ? left->next_sibling : nullptr; + + if (left) emit_node(left); + + // Emit the operator: " UNION ", " UNION ALL ", " INTERSECT ", etc. 
+ sb_.append_char(' '); + emit_value(node); // operator keyword text (UNION, INTERSECT, EXCEPT) + if (node->flags & FLAG_SET_OP_ALL) { + sb_.append(" ALL"); + } + sb_.append_char(' '); + + if (right) emit_node(right); + } + + // ---- Expressions ---- + + void emit_binary_op(const AstNode* node) { + const AstNode* left = node->first_child; + const AstNode* right = left ? left->next_sibling : nullptr; + if (left) emit_node(left); + sb_.append_char(' '); + emit_value(node); + sb_.append_char(' '); + if (right) emit_node(right); + } + + void emit_unary_op(const AstNode* node) { + emit_value(node); + // Add space for keyword operators like NOT, no space for - or + + if (node->value_len > 1) sb_.append_char(' '); + if (node->first_child) emit_node(node->first_child); + } + + void emit_function_call(const AstNode* node) { + emit_value(node); + sb_.append_char('('); + bool first = true; + for (const AstNode* arg = node->first_child; arg; arg = arg->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(arg); + } + sb_.append_char(')'); + } + + void emit_is_null(const AstNode* node) { + if (node->first_child) emit_node(node->first_child); + sb_.append(" IS NULL"); + } + + void emit_is_not_null(const AstNode* node) { + if (node->first_child) emit_node(node->first_child); + sb_.append(" IS NOT NULL"); + } + + void emit_between(const AstNode* node) { + const AstNode* expr = node->first_child; + const AstNode* low = expr ? expr->next_sibling : nullptr; + const AstNode* high = low ? low->next_sibling : nullptr; + if (expr) emit_node(expr); + sb_.append(" BETWEEN "); + if (low) emit_node(low); + sb_.append(" AND "); + if (high) emit_node(high); + } + + void emit_in_list(const AstNode* node) { + const AstNode* expr = node->first_child; + if (expr) emit_node(expr); + sb_.append(" IN ("); + if (mode_ == EmitMode::DIGEST) { + sb_.append_char('?'); + } else { + bool first = true; + for (const AstNode* val = expr ? 
expr->next_sibling : nullptr; val; val = val->next_sibling) { + if (!first) sb_.append(", "); + first = false; + emit_node(val); + } + } + sb_.append_char(')'); + } + + void emit_case_when(const AstNode* node) { + sb_.append("CASE "); + for (const AstNode* child = node->first_child; child; child = child->next_sibling) { + emit_node(child); + sb_.append_char(' '); + } + sb_.append("END"); + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_EMITTER_H diff --git a/include/sql_parser/expression_parser.h b/include/sql_parser/expression_parser.h new file mode 100644 index 0000000..e5430eb --- /dev/null +++ b/include/sql_parser/expression_parser.h @@ -0,0 +1,461 @@ +#ifndef SQL_PARSER_EXPRESSION_PARSER_H +#define SQL_PARSER_EXPRESSION_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" + +namespace sql_parser { + +// Operator precedence levels for Pratt parsing +enum class Precedence : uint8_t { + NONE = 0, + OR, // OR + AND, // AND + NOT, // NOT (prefix) + COMPARISON, // =, <, >, <=, >=, !=, <>, IS, LIKE, IN, BETWEEN + ADDITION, // +, - + MULTIPLICATION,// *, /, % + UNARY, // - (prefix), NOT + POSTFIX, // IS NULL, IS NOT NULL + CALL, // function() + PRIMARY, // literals, identifiers +}; + +template +class ExpressionParser { +public: + ExpressionParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena) {} + + // Parse an expression with minimum precedence 0 + AstNode* parse(Precedence min_prec = Precedence::NONE) { + AstNode* left = parse_atom(); + if (!left) return nullptr; + + while (true) { + Precedence prec = infix_precedence(tok_.peek().type); + if (prec <= min_prec) break; + + left = parse_infix(left, prec); + if (!left) return nullptr; + } + + return left; + } + +private: + Tokenizer& tok_; + Arena& arena_; + + // Parse a primary expression (atom) + AstNode* parse_atom() { + Token t = tok_.peek(); + + switch (t.type) { + 
case TokenType::TK_INTEGER: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_INT, t.text); + } + case TokenType::TK_FLOAT: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_FLOAT, t.text); + } + case TokenType::TK_STRING: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_STRING, t.text); + } + case TokenType::TK_NULL: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_NULL, t.text); + } + case TokenType::TK_TRUE: + case TokenType::TK_FALSE: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_LITERAL_INT, t.text); + } + case TokenType::TK_DEFAULT: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_IDENTIFIER, t.text); + } + case TokenType::TK_ASTERISK: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_ASTERISK, t.text); + } + case TokenType::TK_QUESTION: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_PLACEHOLDER, t.text); + } + case TokenType::TK_DOLLAR_NUM: { + tok_.skip(); + return make_node(arena_, NodeType::NODE_PLACEHOLDER, t.text); + } + case TokenType::TK_AT: { + // User variable: @name + tok_.skip(); + Token name = tok_.next_token(); + // Build @name as a single COLUMN_REF with combined text + // value_ptr points to @ in original input, len covers @name + StringRef full{t.text.ptr, + static_cast((name.text.ptr + name.text.len) - t.text.ptr)}; + return make_node(arena_, NodeType::NODE_COLUMN_REF, full); + } + case TokenType::TK_DOUBLE_AT: { + // System variable: @@name or @@scope.name + tok_.skip(); + Token name = tok_.next_token(); + StringRef full{t.text.ptr, + static_cast((name.text.ptr + name.text.len) - t.text.ptr)}; + AstNode* node = make_node(arena_, NodeType::NODE_COLUMN_REF, full); + // Check for @@scope.name + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token var_name = tok_.next_token(); + full = StringRef{t.text.ptr, + static_cast((var_name.text.ptr + var_name.text.len) - t.text.ptr)}; + node->value_ptr = full.ptr; + 
node->value_len = full.len; + } + return node; + } + case TokenType::TK_MINUS: { + // Unary minus + tok_.skip(); + AstNode* operand = parse(Precedence::UNARY); + if (!operand) return nullptr; + AstNode* node = make_node(arena_, NodeType::NODE_UNARY_OP, t.text); + node->add_child(operand); + return node; + } + case TokenType::TK_PLUS: { + // Unary plus + tok_.skip(); + return parse(Precedence::UNARY); + } + case TokenType::TK_NOT: { + tok_.skip(); + AstNode* operand = parse(Precedence::NOT); + if (!operand) return nullptr; + AstNode* node = make_node(arena_, NodeType::NODE_UNARY_OP, t.text); + node->add_child(operand); + return node; + } + case TokenType::TK_EXISTS: { + tok_.skip(); + // EXISTS (subquery) + AstNode* node = make_node(arena_, NodeType::NODE_SUBQUERY); + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); + skip_to_matching_paren(); + } + return node; + } + case TokenType::TK_CASE: { + tok_.skip(); + return parse_case(); + } + case TokenType::TK_LPAREN: { + tok_.skip(); + // Could be subquery: (SELECT ...) 
+ if (tok_.peek().type == TokenType::TK_SELECT) { + // Subquery — for now, skip to matching paren + AstNode* node = make_node(arena_, NodeType::NODE_SUBQUERY); + skip_to_matching_paren(); + return node; + } + AstNode* expr = parse(); + if (tok_.peek().type == TokenType::TK_RPAREN) { + tok_.skip(); + } + return expr; + } + case TokenType::TK_IDENTIFIER: { + tok_.skip(); + return parse_identifier_or_function(t); + } + // Keywords that can appear as identifiers in expression context + // (e.g., column names that happen to be keywords) + default: { + if (is_keyword_as_identifier(t.type)) { + tok_.skip(); + return parse_identifier_or_function(t); + } + return nullptr; // not an expression + } + } + } + + AstNode* parse_identifier_or_function(const Token& name_token) { + // Check for function call: name( + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); // consume ( + AstNode* func = make_node(arena_, NodeType::NODE_FUNCTION_CALL, name_token.text); + // Parse argument list + if (tok_.peek().type != TokenType::TK_RPAREN) { + while (true) { + AstNode* arg = parse(); + if (arg) func->add_child(arg); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) { + tok_.skip(); + } + return func; + } + + // Check for qualified name: table.column + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); // consume dot + Token col = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name_token.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + return qname; + } + + return make_node(arena_, NodeType::NODE_COLUMN_REF, name_token.text); + } + + // Infix precedence for a token type. + // Returns NONE if not an infix operator (stops the Pratt loop). 
+ static Precedence infix_precedence(TokenType type) { + switch (type) { + case TokenType::TK_OR: return Precedence::OR; + case TokenType::TK_AND: return Precedence::AND; + case TokenType::TK_NOT: return Precedence::COMPARISON; // NOT IN/BETWEEN/LIKE + case TokenType::TK_EQUAL: + case TokenType::TK_NOT_EQUAL: + case TokenType::TK_LESS: + case TokenType::TK_GREATER: + case TokenType::TK_LESS_EQUAL: + case TokenType::TK_GREATER_EQUAL: + case TokenType::TK_LIKE: return Precedence::COMPARISON; + case TokenType::TK_IS: return Precedence::COMPARISON; + case TokenType::TK_IN: return Precedence::COMPARISON; + case TokenType::TK_BETWEEN: return Precedence::COMPARISON; + case TokenType::TK_PLUS: + case TokenType::TK_MINUS: return Precedence::ADDITION; + case TokenType::TK_ASTERISK: + case TokenType::TK_SLASH: + case TokenType::TK_PERCENT: return Precedence::MULTIPLICATION; + case TokenType::TK_DOUBLE_PIPE: return Precedence::ADDITION; // string concat + default: return Precedence::NONE; + } + } + + AstNode* parse_infix(AstNode* left, Precedence prec) { + Token op = tok_.next_token(); + + switch (op.type) { + case TokenType::TK_NOT: { + // NOT IN / NOT BETWEEN / NOT LIKE — compound negated infix + Token actual_op = tok_.peek(); + if (actual_op.type == TokenType::TK_IN) { + tok_.skip(); + AstNode* in_node = parse_in(left); + // Wrap in NOT + AstNode* not_node = make_node(arena_, NodeType::NODE_UNARY_OP, op.text); + not_node->add_child(in_node); + return not_node; + } + if (actual_op.type == TokenType::TK_BETWEEN) { + tok_.skip(); + AstNode* between_node = parse_between(left); + AstNode* not_node = make_node(arena_, NodeType::NODE_UNARY_OP, op.text); + not_node->add_child(between_node); + return not_node; + } + if (actual_op.type == TokenType::TK_LIKE) { + tok_.skip(); + AstNode* right = parse(prec); + AstNode* like_node = make_node(arena_, NodeType::NODE_BINARY_OP, actual_op.text); + like_node->add_child(left); + if (right) like_node->add_child(right); + AstNode* not_node = 
make_node(arena_, NodeType::NODE_UNARY_OP, op.text); + not_node->add_child(like_node); + return not_node; + } + // Standalone NOT in infix position — shouldn't happen, return left + return left; + } + case TokenType::TK_IS: { + // IS [NOT] NULL + bool is_not = false; + if (tok_.peek().type == TokenType::TK_NOT) { + is_not = true; + tok_.skip(); + } + if (tok_.peek().type == TokenType::TK_NULL) { + tok_.skip(); + NodeType nt = is_not ? NodeType::NODE_IS_NOT_NULL : NodeType::NODE_IS_NULL; + AstNode* node = make_node(arena_, nt); + node->add_child(left); + return node; + } + // IS TRUE / IS FALSE / IS NOT TRUE / IS NOT FALSE + if (tok_.peek().type == TokenType::TK_TRUE || tok_.peek().type == TokenType::TK_FALSE) { + Token val = tok_.next_token(); + AstNode* node = make_node(arena_, NodeType::NODE_BINARY_OP, + is_not ? StringRef{"IS NOT", 6} : StringRef{"IS", 2}); + node->add_child(left); + node->add_child(make_node(arena_, NodeType::NODE_LITERAL_INT, val.text)); + return node; + } + return left; + } + case TokenType::TK_IN: + return parse_in(left); + case TokenType::TK_BETWEEN: + return parse_between(left); + default: { + // Standard binary operator + AstNode* right = parse(prec); + if (!right) return left; + AstNode* node = make_node(arena_, NodeType::NODE_BINARY_OP, op.text); + node->add_child(left); + node->add_child(right); + return node; + } + } + } + + // IN (value_list) or IN (subquery) + AstNode* parse_in(AstNode* left) { + AstNode* node = make_node(arena_, NodeType::NODE_IN_LIST); + node->add_child(left); + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_SELECT) { + AstNode* sq = make_node(arena_, NodeType::NODE_SUBQUERY); + skip_to_matching_paren(); + node->add_child(sq); + } else { + while (true) { + AstNode* val = parse(); + if (val) node->add_child(val); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) 
tok_.skip(); + } + } + return node; + } + + // BETWEEN low AND high + AstNode* parse_between(AstNode* left) { + AstNode* node = make_node(arena_, NodeType::NODE_BETWEEN); + node->add_child(left); + AstNode* low = parse(Precedence::COMPARISON); + node->add_child(low); + if (tok_.peek().type == TokenType::TK_AND) { + tok_.skip(); + } + AstNode* high = parse(Precedence::COMPARISON); + node->add_child(high); + return node; + } + + // CASE [expr] WHEN ... THEN ... [ELSE ...] END + AstNode* parse_case() { + AstNode* node = make_node(arena_, NodeType::NODE_CASE_WHEN); + // Optional simple CASE expression: CASE expr WHEN ... + if (tok_.peek().type != TokenType::TK_WHEN) { + AstNode* case_expr = parse(); + if (case_expr) node->add_child(case_expr); + } + // WHEN ... THEN ... pairs + while (tok_.peek().type == TokenType::TK_WHEN) { + tok_.skip(); + AstNode* when_expr = parse(); + if (when_expr) node->add_child(when_expr); + if (tok_.peek().type == TokenType::TK_THEN) tok_.skip(); + AstNode* then_expr = parse(); + if (then_expr) node->add_child(then_expr); + } + // Optional ELSE + if (tok_.peek().type == TokenType::TK_ELSE) { + tok_.skip(); + AstNode* else_expr = parse(); + if (else_expr) node->add_child(else_expr); + } + // END + if (tok_.peek().type == TokenType::TK_END) tok_.skip(); + return node; + } + + // Skip tokens until matching closing paren (handles nesting) + void skip_to_matching_paren() { + int depth = 1; + while (depth > 0) { + Token t = tok_.next_token(); + if (t.type == TokenType::TK_LPAREN) ++depth; + else if (t.type == TokenType::TK_RPAREN) --depth; + else if (t.type == TokenType::TK_EOF) break; + } + } + + // Some keywords can appear as identifiers in expression context + static bool is_keyword_as_identifier(TokenType type) { + switch (type) { + // Keywords commonly used as column/table names + case TokenType::TK_COUNT: + case TokenType::TK_SUM: + case TokenType::TK_AVG: + case TokenType::TK_MIN: + case TokenType::TK_MAX: + case TokenType::TK_IF: + case 
TokenType::TK_VALUES: + case TokenType::TK_DATABASE: + case TokenType::TK_SCHEMA: + case TokenType::TK_TABLE: + case TokenType::TK_INDEX: + case TokenType::TK_VIEW: + case TokenType::TK_NAMES: + case TokenType::TK_CHARACTER: + case TokenType::TK_CHARSET: + case TokenType::TK_GLOBAL: + case TokenType::TK_SESSION: + case TokenType::TK_LOCAL: + case TokenType::TK_LEVEL: + case TokenType::TK_READ: + case TokenType::TK_WRITE: + case TokenType::TK_ONLY: + case TokenType::TK_TRANSACTION: + case TokenType::TK_ISOLATION: + case TokenType::TK_COMMITTED: + case TokenType::TK_UNCOMMITTED: + case TokenType::TK_REPEATABLE: + case TokenType::TK_SERIALIZABLE: + case TokenType::TK_SHARE: + case TokenType::TK_DATA: + case TokenType::TK_RESET: + case TokenType::TK_KEY: + case TokenType::TK_DO: + case TokenType::TK_NOTHING: + case TokenType::TK_CONFLICT: + case TokenType::TK_CONSTRAINT: + case TokenType::TK_RETURNING: + case TokenType::TK_DUPLICATE: + case TokenType::TK_DELAYED: + case TokenType::TK_HIGH_PRIORITY: + return true; + default: + return false; + } + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_EXPRESSION_PARSER_H diff --git a/include/sql_parser/insert_parser.h b/include/sql_parser/insert_parser.h new file mode 100644 index 0000000..420cad8 --- /dev/null +++ b/include/sql_parser/insert_parser.h @@ -0,0 +1,441 @@ +#ifndef SQL_PARSER_INSERT_PARSER_H +#define SQL_PARSER_INSERT_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" +#include "sql_parser/select_parser.h" + +namespace sql_parser { + +// Flag on NODE_INSERT_STMT to indicate REPLACE +static constexpr uint16_t FLAG_REPLACE = 0x01; + +template +class InsertParser { +public: + InsertParser(Tokenizer& tokenizer, Arena& arena, bool is_replace = false) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, 
arena), + table_ref_parser_(tokenizer, arena, expr_parser_), + is_replace_(is_replace) {} + + // Parse INSERT/REPLACE statement (INSERT/REPLACE keyword already consumed). + AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_INSERT_STMT, {}, + is_replace_ ? FLAG_REPLACE : uint16_t(0)); + if (!root) return nullptr; + + // MySQL options: [LOW_PRIORITY | DELAYED | HIGH_PRIORITY] [IGNORE] + if constexpr (D == Dialect::MySQL) { + AstNode* opts = parse_stmt_options(); + if (opts) root->add_child(opts); + } + + // Optional INTO keyword + if (tok_.peek().type == TokenType::TK_INTO) { + tok_.skip(); + } + + // Table reference + AstNode* table_ref = table_ref_parser_.parse_table_reference(); + if (table_ref) root->add_child(table_ref); + + // Check for column list or go straight to data source + // Column list: (col1, col2, ...) + // Need to distinguish (col_list) from VALUES (row) — peek ahead + if (tok_.peek().type == TokenType::TK_LPAREN) { + // Could be column list or VALUES row without VALUES keyword + // If VALUES/SET/SELECT/DEFAULT follows later, this is a column list + // Heuristic: peek inside the parens — if followed by VALUES/SET/SELECT/DEFAULT/ON/RETURNING/;/EOF, it's columns + // Actually: column list is (identifiers), VALUES clause has VALUES keyword before parens + // So if next is LPAREN and it's NOT preceded by VALUES, it's the column list + if (tok_.peek().type == TokenType::TK_LPAREN && + !is_values_next()) { + AstNode* cols = parse_column_list(); + if (cols) root->add_child(cols); + } + } + + // Data source: VALUES | SELECT | SET | DEFAULT VALUES + Token next = tok_.peek(); + if (next.type == TokenType::TK_VALUES) { + tok_.skip(); + AstNode* values = parse_values_clause(); + if (values) root->add_child(values); + } else if (next.type == TokenType::TK_SELECT) { + // INSERT ... 
SELECT + tok_.skip(); // consume SELECT + SelectParser select_parser(tok_, arena_); + AstNode* select = select_parser.parse(); + if (select) root->add_child(select); + } else if constexpr (D == Dialect::MySQL) { + if (next.type == TokenType::TK_SET) { + tok_.skip(); + AstNode* set_clause = parse_insert_set_clause(); + if (set_clause) root->add_child(set_clause); + } + } + if constexpr (D == Dialect::PostgreSQL) { + if (next.type == TokenType::TK_DEFAULT) { + tok_.skip(); // consume DEFAULT + if (tok_.peek().type == TokenType::TK_VALUES) { + tok_.skip(); // consume VALUES + // Store as a VALUES clause with no rows (signals DEFAULT VALUES) + AstNode* values = make_node(arena_, NodeType::NODE_VALUES_CLAUSE, + StringRef{"DEFAULT VALUES", 14}); + root->add_child(values); + } + } + } + + // MySQL: ON DUPLICATE KEY UPDATE + if constexpr (D == Dialect::MySQL) { + if (tok_.peek().type == TokenType::TK_ON) { + AstNode* odku = parse_on_duplicate_key(); + if (odku) root->add_child(odku); + } + } + + // PostgreSQL: ON CONFLICT ... and RETURNING + if constexpr (D == Dialect::PostgreSQL) { + if (tok_.peek().type == TokenType::TK_ON) { + AstNode* oc = parse_on_conflict(); + if (oc) root->add_child(oc); + } + if (tok_.peek().type == TokenType::TK_RETURNING) { + AstNode* ret = parse_returning(); + if (ret) root->add_child(ret); + } + } + + return root; + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser table_ref_parser_; + bool is_replace_; + + // Check if we're looking at a VALUES keyword (not a column list paren) + bool is_values_next() { + // The LPAREN is for column list, not VALUES row + // This is only called when we see LPAREN and need to decide + // A column list is always followed by VALUES/SET/SELECT/DEFAULT + // Actually the approach is simpler: if the next token is LPAREN + // and the token before was the table ref (no VALUES keyword yet), + // it's the column list. 
+ return false; // caller only calls this when peeking at LPAREN + } + + // Parse MySQL options: LOW_PRIORITY, DELAYED, HIGH_PRIORITY, IGNORE + AstNode* parse_stmt_options() { + AstNode* opts = nullptr; + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_LOW_PRIORITY || + t.type == TokenType::TK_DELAYED || + t.type == TokenType::TK_HIGH_PRIORITY || + t.type == TokenType::TK_IGNORE) { + if (!opts) opts = make_node(arena_, NodeType::NODE_STMT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else { + break; + } + } + return opts; + } + + // Parse column list: (col1, col2, ...) + AstNode* parse_column_list() { + AstNode* cols = make_node(arena_, NodeType::NODE_INSERT_COLUMNS); + if (!cols) return nullptr; + + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); // consume ( + while (true) { + Token col = tok_.next_token(); + cols->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) { + tok_.skip(); // consume ) + } + } + return cols; + } + + // Parse VALUES clause: (row1), (row2), ... + AstNode* parse_values_clause() { + AstNode* values = make_node(arena_, NodeType::NODE_VALUES_CLAUSE); + if (!values) return nullptr; + + while (true) { + AstNode* row = parse_values_row(); + if (row) values->add_child(row); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return values; + } + + // Parse a single values row: (expr, expr, ...) 
+ AstNode* parse_values_row() { + AstNode* row = make_node(arena_, NodeType::NODE_VALUES_ROW); + if (!row) return nullptr; + + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); // consume ( + while (true) { + AstNode* val = expr_parser_.parse(); + if (val) row->add_child(val); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) { + tok_.skip(); // consume ) + } + } + return row; + } + + // Parse MySQL SET form: col=val, col=val, ... + AstNode* parse_insert_set_clause() { + AstNode* set_clause = make_node(arena_, NodeType::NODE_INSERT_SET_CLAUSE); + if (!set_clause) return nullptr; + + while (true) { + AstNode* item = parse_set_item(); + if (item) set_clause->add_child(item); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return set_clause; + } + + // Parse a single col=expr pair + AstNode* parse_set_item() { + AstNode* item = make_node(arena_, NodeType::NODE_UPDATE_SET_ITEM); + if (!item) return nullptr; + + // Column name (may be qualified: table.col) + Token col = tok_.next_token(); + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token actual_col = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, actual_col.text)); + item->add_child(qname); + } else { + item->add_child(make_node(arena_, NodeType::NODE_COLUMN_REF, col.text)); + } + + // = sign + if (tok_.peek().type == TokenType::TK_EQUAL) { + tok_.skip(); + } + + // Expression value + AstNode* val = expr_parser_.parse(); + if (val) item->add_child(val); + + return item; + } + + // Parse MySQL ON DUPLICATE KEY UPDATE col=val, ... 
+ AstNode* parse_on_duplicate_key() { + // Expect: ON DUPLICATE KEY UPDATE + if (tok_.peek().type != TokenType::TK_ON) return nullptr; + tok_.skip(); // ON + + if (tok_.peek().type != TokenType::TK_DUPLICATE) return nullptr; + tok_.skip(); // DUPLICATE + + if (tok_.peek().type != TokenType::TK_KEY) return nullptr; + tok_.skip(); // KEY + + if (tok_.peek().type != TokenType::TK_UPDATE) return nullptr; + tok_.skip(); // UPDATE + + AstNode* odku = make_node(arena_, NodeType::NODE_ON_DUPLICATE_KEY); + if (!odku) return nullptr; + + // Parse SET items + while (true) { + AstNode* item = parse_set_item(); + if (item) odku->add_child(item); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + + return odku; + } + + // Parse PostgreSQL ON CONFLICT ... + AstNode* parse_on_conflict() { + // Expect: ON CONFLICT + if (tok_.peek().type != TokenType::TK_ON) return nullptr; + tok_.skip(); // ON + + if (tok_.peek().type != TokenType::TK_CONFLICT) return nullptr; + tok_.skip(); // CONFLICT + + AstNode* oc = make_node(arena_, NodeType::NODE_ON_CONFLICT); + if (!oc) return nullptr; + + // Optional conflict target: (cols) or ON CONSTRAINT name + if (tok_.peek().type == TokenType::TK_LPAREN) { + AstNode* target = parse_conflict_target_cols(); + if (target) oc->add_child(target); + } else if (tok_.peek().type == TokenType::TK_ON) { + // ON CONSTRAINT name + AstNode* target = parse_conflict_target_constraint(); + if (target) oc->add_child(target); + } + + // DO UPDATE SET ... or DO NOTHING + if (tok_.peek().type == TokenType::TK_DO) { + AstNode* action = parse_conflict_action(); + if (action) oc->add_child(action); + } + + return oc; + } + + // Parse conflict target: (col1, col2, ...) 
+ AstNode* parse_conflict_target_cols() { + AstNode* target = make_node(arena_, NodeType::NODE_CONFLICT_TARGET); + if (!target) return nullptr; + + tok_.skip(); // consume ( + while (true) { + Token col = tok_.next_token(); + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) tok_.skip(); + + return target; + } + + // Parse ON CONSTRAINT name + AstNode* parse_conflict_target_constraint() { + AstNode* target = make_node(arena_, NodeType::NODE_CONFLICT_TARGET, + StringRef{"ON CONSTRAINT", 13}); + if (!target) return nullptr; + + tok_.skip(); // ON + if (tok_.peek().type == TokenType::TK_CONSTRAINT) { + tok_.skip(); // CONSTRAINT + } + Token name = tok_.next_token(); + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + + return target; + } + + // Parse DO UPDATE SET ... WHERE ... or DO NOTHING + AstNode* parse_conflict_action() { + if (tok_.peek().type != TokenType::TK_DO) return nullptr; + tok_.skip(); // DO + + AstNode* action = make_node(arena_, NodeType::NODE_CONFLICT_ACTION); + if (!action) return nullptr; + + if (tok_.peek().type == TokenType::TK_NOTHING) { + tok_.skip(); + action->set_value(StringRef{"NOTHING", 7}); + } else if (tok_.peek().type == TokenType::TK_UPDATE) { + tok_.skip(); // UPDATE + action->set_value(StringRef{"UPDATE", 6}); + + if (tok_.peek().type == TokenType::TK_SET) { + tok_.skip(); // SET + } + + // Parse SET items + while (true) { + AstNode* item = parse_set_item(); + if (item) action->add_child(item); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + + // Optional WHERE + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = make_node(arena_, NodeType::NODE_WHERE_CLAUSE); + AstNode* expr = expr_parser_.parse(); + if (expr) where->add_child(expr); + action->add_child(where); + } + 
} + + return action; + } + + // Parse PostgreSQL RETURNING expr_list + AstNode* parse_returning() { + if (tok_.peek().type != TokenType::TK_RETURNING) return nullptr; + tok_.skip(); // RETURNING + + AstNode* ret = make_node(arena_, NodeType::NODE_RETURNING_CLAUSE); + if (!ret) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + ret->add_child(expr); + + // Check for optional alias + Token next = tok_.peek(); + if (next.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + ret->add_child(make_node(arena_, NodeType::NODE_ALIAS, alias_name.text)); + } + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + + return ret; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_INSERT_PARSER_H diff --git a/include/sql_parser/keywords_mysql.h b/include/sql_parser/keywords_mysql.h new file mode 100644 index 0000000..b15cf61 --- /dev/null +++ b/include/sql_parser/keywords_mysql.h @@ -0,0 +1,161 @@ +#ifndef SQL_PARSER_KEYWORDS_MYSQL_H +#define SQL_PARSER_KEYWORDS_MYSQL_H + +#include "sql_parser/token.h" +#include +#include + +namespace sql_parser { +namespace mysql_keywords { + +struct KeywordEntry { + const char* text; + uint8_t len; + TokenType token; +}; + +inline constexpr KeywordEntry KEYWORDS[] = { + {"ALL", 3, TokenType::TK_ALL}, + {"ALTER", 5, TokenType::TK_ALTER}, + {"AND", 3, TokenType::TK_AND}, + {"AS", 2, TokenType::TK_AS}, + {"ASC", 3, TokenType::TK_ASC}, + {"AVG", 3, TokenType::TK_AVG}, + {"BEGIN", 5, TokenType::TK_BEGIN}, + {"BETWEEN", 7, TokenType::TK_BETWEEN}, + {"BY", 2, TokenType::TK_BY}, + {"CASE", 4, TokenType::TK_CASE}, + {"CHARACTER", 9, TokenType::TK_CHARACTER}, + {"CHARSET", 7, TokenType::TK_CHARSET}, + {"COLLATE", 7, TokenType::TK_COLLATE}, + {"COMMIT", 6, TokenType::TK_COMMIT}, + {"COMMITTED", 9, TokenType::TK_COMMITTED}, + {"CONFLICT", 8, TokenType::TK_CONFLICT}, + {"CONSTRAINT", 10, TokenType::TK_CONSTRAINT}, + {"COUNT", 5, 
TokenType::TK_COUNT}, + {"CREATE", 6, TokenType::TK_CREATE}, + {"CROSS", 5, TokenType::TK_CROSS}, + {"DATA", 4, TokenType::TK_DATA}, + {"DATABASE", 8, TokenType::TK_DATABASE}, + {"DEALLOCATE", 10, TokenType::TK_DEALLOCATE}, + {"DEFAULT", 7, TokenType::TK_DEFAULT}, + {"DELAYED", 7, TokenType::TK_DELAYED}, + {"DELETE", 6, TokenType::TK_DELETE}, + {"DESC", 4, TokenType::TK_DESC}, + {"DISTINCT", 8, TokenType::TK_DISTINCT}, + {"DO", 2, TokenType::TK_DO}, + {"DROP", 4, TokenType::TK_DROP}, + {"DUMPFILE", 8, TokenType::TK_DUMPFILE}, + {"DUPLICATE", 9, TokenType::TK_DUPLICATE}, + {"ELSE", 4, TokenType::TK_ELSE}, + {"END", 3, TokenType::TK_END}, + {"EXCEPT", 6, TokenType::TK_EXCEPT}, + {"EXECUTE", 7, TokenType::TK_EXECUTE}, + {"EXISTS", 6, TokenType::TK_EXISTS}, + {"FALSE", 5, TokenType::TK_FALSE}, + {"FETCH", 5, TokenType::TK_FETCH}, + {"FOR", 3, TokenType::TK_FOR}, + {"FROM", 4, TokenType::TK_FROM}, + {"FULL", 4, TokenType::TK_FULL}, + {"GLOBAL", 6, TokenType::TK_GLOBAL}, + {"GRANT", 5, TokenType::TK_GRANT}, + {"GROUP", 5, TokenType::TK_GROUP}, + {"HAVING", 6, TokenType::TK_HAVING}, + {"HIGH_PRIORITY", 13, TokenType::TK_HIGH_PRIORITY}, + {"IF", 2, TokenType::TK_IF}, + {"IGNORE", 6, TokenType::TK_IGNORE}, + {"IN", 2, TokenType::TK_IN}, + {"INDEX", 5, TokenType::TK_INDEX}, + {"INNER", 5, TokenType::TK_INNER}, + {"INSERT", 6, TokenType::TK_INSERT}, + {"INTERSECT", 9, TokenType::TK_INTERSECT}, + {"INTO", 4, TokenType::TK_INTO}, + {"IS", 2, TokenType::TK_IS}, + {"ISOLATION", 9, TokenType::TK_ISOLATION}, + {"JOIN", 4, TokenType::TK_JOIN}, + {"KEY", 3, TokenType::TK_KEY}, + {"LEFT", 4, TokenType::TK_LEFT}, + {"LEVEL", 5, TokenType::TK_LEVEL}, + {"LIKE", 4, TokenType::TK_LIKE}, + {"LIMIT", 5, TokenType::TK_LIMIT}, + {"LOAD", 4, TokenType::TK_LOAD}, + {"LOCAL", 5, TokenType::TK_LOCAL}, + {"LOCK", 4, TokenType::TK_LOCK}, + {"LOCKED", 6, TokenType::TK_LOCKED}, + {"LOW_PRIORITY", 12, TokenType::TK_LOW_PRIORITY}, + {"MAX", 3, TokenType::TK_MAX}, + {"MIN", 3, TokenType::TK_MIN}, + 
{"NAMES", 5, TokenType::TK_NAMES},
+    {"NATURAL", 7, TokenType::TK_NATURAL},
+    {"NOT", 3, TokenType::TK_NOT},
+    {"NOTHING", 7, TokenType::TK_NOTHING},
+    {"NOWAIT", 6, TokenType::TK_NOWAIT},
+    {"NULL", 4, TokenType::TK_NULL},
+    {"OF", 2, TokenType::TK_OF},
+    {"OFFSET", 6, TokenType::TK_OFFSET},
+    {"ON", 2, TokenType::TK_ON},
+    {"ONLY", 4, TokenType::TK_ONLY},
+    {"OR", 2, TokenType::TK_OR},
+    {"ORDER", 5, TokenType::TK_ORDER},
+    {"OUTER", 5, TokenType::TK_OUTER},
+    {"OUTFILE", 7, TokenType::TK_OUTFILE},
+    {"PERSIST", 7, TokenType::TK_PERSIST},
+    {"PREPARE", 7, TokenType::TK_PREPARE},
+    {"QUICK", 5, TokenType::TK_QUICK},
+    {"READ", 4, TokenType::TK_READ},
+    {"REPEATABLE", 10, TokenType::TK_REPEATABLE},
+    {"REPLACE", 7, TokenType::TK_REPLACE},
+    {"RESET", 5, TokenType::TK_RESET},
+    {"RETURNING", 9, TokenType::TK_RETURNING},
+    {"REVOKE", 6, TokenType::TK_REVOKE},
+    {"RIGHT", 5, TokenType::TK_RIGHT},
+    {"ROLLBACK", 8, TokenType::TK_ROLLBACK},
+    {"SAVEPOINT", 9, TokenType::TK_SAVEPOINT},
+    {"SCHEMA", 6, TokenType::TK_SCHEMA},
+    {"SELECT", 6, TokenType::TK_SELECT},
+    {"SERIALIZABLE", 12, TokenType::TK_SERIALIZABLE},
+    {"SESSION", 7, TokenType::TK_SESSION},
+    {"SET", 3, TokenType::TK_SET},
+    {"SHARE", 5, TokenType::TK_SHARE},
+    {"SHOW", 4, TokenType::TK_SHOW},
+    {"SKIP", 4, TokenType::TK_SKIP},
+    {"SQL_CALC_FOUND_ROWS", 19, TokenType::TK_SQL_CALC_FOUND_ROWS},
+    {"START", 5, TokenType::TK_START},
+    {"SUM", 3, TokenType::TK_SUM},
+    {"TABLE", 5, TokenType::TK_TABLE},
+    {"THEN", 4, TokenType::TK_THEN},
+    {"TO", 2, TokenType::TK_TO},
+    {"TRANSACTION", 11, TokenType::TK_TRANSACTION},
+    {"TRUE", 4, TokenType::TK_TRUE},
+    {"TRUNCATE", 8, TokenType::TK_TRUNCATE},
+    {"UNCOMMITTED", 11, TokenType::TK_UNCOMMITTED},
+    {"UNION", 5, TokenType::TK_UNION},
+    {"UNLOCK", 6, TokenType::TK_UNLOCK},
+    {"UPDATE", 6, TokenType::TK_UPDATE},
+    {"USE", 3, TokenType::TK_USE},
+    {"USING", 5, TokenType::TK_USING},
+    {"VALUES", 6, TokenType::TK_VALUES},
+    {"VIEW", 4, TokenType::TK_VIEW},
+    {"WHEN", 4, 
TokenType::TK_WHEN}, + {"WHERE", 5, TokenType::TK_WHERE}, + {"WRITE", 5, TokenType::TK_WRITE}, +}; + +inline constexpr size_t KEYWORD_COUNT = sizeof(KEYWORDS) / sizeof(KEYWORDS[0]); + +inline TokenType lookup(const char* text, uint32_t len) { + size_t lo = 0, hi = KEYWORD_COUNT; + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + int cmp = sql_parser::ci_cmp(text, len, KEYWORDS[mid].text, KEYWORDS[mid].len); + if (cmp == 0) return KEYWORDS[mid].token; + if (cmp < 0) hi = mid; + else lo = mid + 1; + } + return TokenType::TK_IDENTIFIER; +} + +} // namespace mysql_keywords +} // namespace sql_parser + +#endif // SQL_PARSER_KEYWORDS_MYSQL_H diff --git a/include/sql_parser/keywords_pgsql.h b/include/sql_parser/keywords_pgsql.h new file mode 100644 index 0000000..4af9886 --- /dev/null +++ b/include/sql_parser/keywords_pgsql.h @@ -0,0 +1,142 @@ +#ifndef SQL_PARSER_KEYWORDS_PGSQL_H +#define SQL_PARSER_KEYWORDS_PGSQL_H + +#include "sql_parser/token.h" + +namespace sql_parser { +namespace pgsql_keywords { + +struct KeywordEntry { + const char* text; + uint8_t len; + TokenType token; +}; + +inline constexpr KeywordEntry KEYWORDS[] = { + {"ALL", 3, TokenType::TK_ALL}, + {"ALTER", 5, TokenType::TK_ALTER}, + {"AND", 3, TokenType::TK_AND}, + {"AS", 2, TokenType::TK_AS}, + {"ASC", 3, TokenType::TK_ASC}, + {"AVG", 3, TokenType::TK_AVG}, + {"BEGIN", 5, TokenType::TK_BEGIN}, + {"BETWEEN", 7, TokenType::TK_BETWEEN}, + {"BY", 2, TokenType::TK_BY}, + {"CASE", 4, TokenType::TK_CASE}, + {"CHARACTER", 9, TokenType::TK_CHARACTER}, + {"COLLATE", 7, TokenType::TK_COLLATE}, + {"COMMIT", 6, TokenType::TK_COMMIT}, + {"COMMITTED", 9, TokenType::TK_COMMITTED}, + {"CONFLICT", 8, TokenType::TK_CONFLICT}, + {"CONSTRAINT", 10, TokenType::TK_CONSTRAINT}, + {"COUNT", 5, TokenType::TK_COUNT}, + {"CREATE", 6, TokenType::TK_CREATE}, + {"CROSS", 5, TokenType::TK_CROSS}, + {"DATA", 4, TokenType::TK_DATA}, + {"DATABASE", 8, TokenType::TK_DATABASE}, + {"DEALLOCATE", 10, TokenType::TK_DEALLOCATE}, + 
{"DEFAULT", 7, TokenType::TK_DEFAULT}, + {"DELETE", 6, TokenType::TK_DELETE}, + {"DESC", 4, TokenType::TK_DESC}, + {"DISTINCT", 8, TokenType::TK_DISTINCT}, + {"DO", 2, TokenType::TK_DO}, + {"DROP", 4, TokenType::TK_DROP}, + {"ELSE", 4, TokenType::TK_ELSE}, + {"END", 3, TokenType::TK_END}, + {"EXCEPT", 6, TokenType::TK_EXCEPT}, + {"EXECUTE", 7, TokenType::TK_EXECUTE}, + {"EXISTS", 6, TokenType::TK_EXISTS}, + {"FALSE", 5, TokenType::TK_FALSE}, + {"FETCH", 5, TokenType::TK_FETCH}, + {"FOR", 3, TokenType::TK_FOR}, + {"FROM", 4, TokenType::TK_FROM}, + {"FULL", 4, TokenType::TK_FULL}, + {"GRANT", 5, TokenType::TK_GRANT}, + {"GROUP", 5, TokenType::TK_GROUP}, + {"HAVING", 6, TokenType::TK_HAVING}, + {"IF", 2, TokenType::TK_IF}, + {"IN", 2, TokenType::TK_IN}, + {"INDEX", 5, TokenType::TK_INDEX}, + {"INNER", 5, TokenType::TK_INNER}, + {"INSERT", 6, TokenType::TK_INSERT}, + {"INTERSECT", 9, TokenType::TK_INTERSECT}, + {"INTO", 4, TokenType::TK_INTO}, + {"IS", 2, TokenType::TK_IS}, + {"ISOLATION", 9, TokenType::TK_ISOLATION}, + {"JOIN", 4, TokenType::TK_JOIN}, + {"LEFT", 4, TokenType::TK_LEFT}, + {"LEVEL", 5, TokenType::TK_LEVEL}, + {"LIKE", 4, TokenType::TK_LIKE}, + {"LIMIT", 5, TokenType::TK_LIMIT}, + {"LOAD", 4, TokenType::TK_LOAD}, + {"LOCAL", 5, TokenType::TK_LOCAL}, + {"LOCK", 4, TokenType::TK_LOCK}, + {"MAX", 3, TokenType::TK_MAX}, + {"MIN", 3, TokenType::TK_MIN}, + {"NAMES", 5, TokenType::TK_NAMES}, + {"NATURAL", 7, TokenType::TK_NATURAL}, + {"NOT", 3, TokenType::TK_NOT}, + {"NOTHING", 7, TokenType::TK_NOTHING}, + {"NULL", 4, TokenType::TK_NULL}, + {"OF", 2, TokenType::TK_OF}, + {"OFFSET", 6, TokenType::TK_OFFSET}, + {"ON", 2, TokenType::TK_ON}, + {"ONLY", 4, TokenType::TK_ONLY}, + {"OR", 2, TokenType::TK_OR}, + {"ORDER", 5, TokenType::TK_ORDER}, + {"OUTER", 5, TokenType::TK_OUTER}, + {"PREPARE", 7, TokenType::TK_PREPARE}, + {"READ", 4, TokenType::TK_READ}, + {"REPEATABLE", 10, TokenType::TK_REPEATABLE}, + {"RESET", 5, TokenType::TK_RESET}, + {"RETURNING", 9, 
TokenType::TK_RETURNING},
+    {"REVOKE", 6, TokenType::TK_REVOKE},
+    {"RIGHT", 5, TokenType::TK_RIGHT},
+    {"ROLLBACK", 8, TokenType::TK_ROLLBACK},
+    {"SAVEPOINT", 9, TokenType::TK_SAVEPOINT},
+    {"SCHEMA", 6, TokenType::TK_SCHEMA},
+    {"SELECT", 6, TokenType::TK_SELECT},
+    {"SERIALIZABLE", 12, TokenType::TK_SERIALIZABLE},
+    {"SESSION", 7, TokenType::TK_SESSION},
+    {"SET", 3, TokenType::TK_SET},
+    {"SHARE", 5, TokenType::TK_SHARE},
+    {"SHOW", 4, TokenType::TK_SHOW},
+    {"START", 5, TokenType::TK_START},
+    {"SUM", 3, TokenType::TK_SUM},
+    {"TABLE", 5, TokenType::TK_TABLE},
+    {"THEN", 4, TokenType::TK_THEN},
+    {"TO", 2, TokenType::TK_TO},
+    {"TRANSACTION", 11, TokenType::TK_TRANSACTION},
+    {"TRUE", 4, TokenType::TK_TRUE},
+    {"TRUNCATE", 8, TokenType::TK_TRUNCATE},
+    {"UNCOMMITTED", 11, TokenType::TK_UNCOMMITTED},
+    {"UNION", 5, TokenType::TK_UNION},
+    {"UNLOCK", 6, TokenType::TK_UNLOCK},
+    {"UPDATE", 6, TokenType::TK_UPDATE},
+    {"USE", 3, TokenType::TK_USE},
+    {"USING", 5, TokenType::TK_USING},
+    {"VALUES", 6, TokenType::TK_VALUES},
+    {"VIEW", 4, TokenType::TK_VIEW},
+    {"WHEN", 4, TokenType::TK_WHEN},
+    {"WHERE", 5, TokenType::TK_WHERE},
+    {"WRITE", 5, TokenType::TK_WRITE},
+};
+
+inline constexpr size_t KEYWORD_COUNT = sizeof(KEYWORDS) / sizeof(KEYWORDS[0]);
+
+inline TokenType lookup(const char* text, uint32_t len) {
+    size_t lo = 0, hi = KEYWORD_COUNT;
+    while (lo < hi) {
+        size_t mid = lo + (hi - lo) / 2;
+        int cmp = sql_parser::ci_cmp(text, len, KEYWORDS[mid].text, KEYWORDS[mid].len);
+        if (cmp == 0) return KEYWORDS[mid].token;
+        if (cmp < 0) hi = mid;
+        else lo = mid + 1;
+    }
+    return TokenType::TK_IDENTIFIER;
+}
+
+} // namespace pgsql_keywords
+} // namespace sql_parser
+
+#endif // SQL_PARSER_KEYWORDS_PGSQL_H
diff --git a/include/sql_parser/parse_result.h b/include/sql_parser/parse_result.h
new file mode 100644
index 0000000..41853a7
--- /dev/null
+++ b/include/sql_parser/parse_result.h
@@ -0,0 +1,55 @@
+#ifndef SQL_PARSER_PARSE_RESULT_H
+#define 
SQL_PARSER_PARSE_RESULT_H + +#include "sql_parser/common.h" +#include "sql_parser/ast.h" + +namespace sql_parser { + +struct ErrorInfo { + uint32_t offset = 0; + StringRef message; +}; + +struct BoundValue { + enum Type : uint8_t { INT, FLOAT, DOUBLE, STRING, BLOB, NULL_VAL, DATETIME, DECIMAL }; + Type type = NULL_VAL; + union { + int64_t int_val; + float float32_val; + double float64_val; + StringRef str_val; + }; + + BoundValue() : type(NULL_VAL), int_val(0) {} + BoundValue(const BoundValue&) = default; + BoundValue& operator=(const BoundValue&) = default; +}; + +struct ParamBindings { + BoundValue* values = nullptr; + uint16_t count = 0; +}; + +struct ParseResult { + enum Status : uint8_t { OK = 0, PARTIAL, ERROR }; + + Status status = ERROR; + StmtType stmt_type = StmtType::UNKNOWN; + AstNode* ast = nullptr; + ErrorInfo error; + StringRef remaining; + + StringRef table_name; + StringRef schema_name; + StringRef database_name; + + ParamBindings bindings; // populated by execute() + + bool ok() const { return status == OK; } + bool has_remaining() const { return !remaining.empty(); } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_PARSE_RESULT_H diff --git a/include/sql_parser/parser.h b/include/sql_parser/parser.h new file mode 100644 index 0000000..7ec5f7c --- /dev/null +++ b/include/sql_parser/parser.h @@ -0,0 +1,90 @@ +#ifndef SQL_PARSER_PARSER_H +#define SQL_PARSER_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/parse_result.h" +#include "sql_parser/stmt_cache.h" + +namespace sql_parser { + +struct ParserConfig { + size_t arena_block_size = 65536; // 64KB + size_t arena_max_size = 1048576; // 1MB + size_t stmt_cache_capacity = 128; +}; + +template +class Parser { +public: + explicit Parser(const ParserConfig& config = {}); + ~Parser() = default; + + // Non-copyable, non-movable + Parser(const Parser&) = delete; + Parser& operator=(const 
Parser&) = delete; + + // Parse a SQL string. Returns ParseResult with classification + metadata. + // For Tier 1 statements (SELECT, SET), returns PARTIAL until deep parsers + // are implemented (future plan). + ParseResult parse(const char* sql, size_t len); + + // Reset the arena. Call after each query is fully processed. + void reset(); + + // Access the arena (for emitter use) + Arena& arena() { return arena_; } + + // Prepared statement support + ParseResult parse_and_cache(const char* sql, size_t len, uint32_t stmt_id); + ParseResult execute(uint32_t stmt_id, const ParamBindings& params); + void prepare_cache_evict(uint32_t stmt_id); + +private: + Arena arena_; + Tokenizer tokenizer_; + StmtCache stmt_cache_; + + // Classifier: dispatches to the right extractor/parser + ParseResult classify_and_dispatch(); + + // Tier 1 parsers + ParseResult parse_select(); + ParseResult parse_select_from_lparen(); + ParseResult parse_set(); + ParseResult parse_insert(bool is_replace = false); + ParseResult parse_update(); + ParseResult parse_delete(); + + // Tier 2 extractors + ParseResult extract_insert(const Token& first); + ParseResult extract_update(const Token& first); + ParseResult extract_delete(const Token& first); + ParseResult extract_replace(const Token& first); + ParseResult extract_transaction(const Token& first); + ParseResult extract_use(const Token& first); + ParseResult extract_show(const Token& first); + ParseResult extract_prepare(const Token& first); + ParseResult extract_execute(const Token& first); + ParseResult extract_deallocate(const Token& first); + ParseResult extract_ddl(const Token& first); + ParseResult extract_acl(const Token& first); + ParseResult extract_lock(const Token& first); + ParseResult extract_load(const Token& first); + ParseResult extract_reset(const Token& first); + ParseResult extract_unknown(const Token& first); + + // Helpers + // Read optional schema.table or just table. Returns table token. 
+ // If qualified (schema.table), sets schema_out. + Token read_table_name(StringRef& schema_out); + + // Scan forward to semicolon or EOF, set result.remaining + void scan_to_end(ParseResult& result); +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_PARSER_H diff --git a/include/sql_parser/select_parser.h b/include/sql_parser/select_parser.h new file mode 100644 index 0000000..2bd0c90 --- /dev/null +++ b/include/sql_parser/select_parser.h @@ -0,0 +1,351 @@ +#ifndef SQL_PARSER_SELECT_PARSER_H +#define SQL_PARSER_SELECT_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" + +namespace sql_parser { + +template +class SelectParser { +public: + SelectParser(Tokenizer& tokenizer, Arena& arena, bool compound_mode = false) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena), + table_ref_parser_(tokenizer, arena, expr_parser_), + compound_mode_(compound_mode) {} + + // Parse a SELECT statement (SELECT keyword already consumed by classifier). + // In compound_mode, stops before ORDER BY / LIMIT so they can be claimed + // by the compound query parser. 
+ AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_SELECT_STMT); + if (!root) return nullptr; + + // SELECT options: DISTINCT, ALL, SQL_CALC_FOUND_ROWS + AstNode* opts = parse_select_options(); + if (opts) root->add_child(opts); + + // Select item list + AstNode* items = parse_select_item_list(); + if (items) root->add_child(items); + + // INTO (before FROM in some MySQL variants -- skip for now, handle after FROM) + + // FROM clause + if (tok_.peek().type == TokenType::TK_FROM) { + tok_.skip(); + AstNode* from = table_ref_parser_.parse_from_clause(); + if (from) root->add_child(from); + } + + // WHERE clause + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + + // GROUP BY clause + if (tok_.peek().type == TokenType::TK_GROUP) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + AstNode* group_by = parse_group_by(); + if (group_by) root->add_child(group_by); + } + + // HAVING clause + if (tok_.peek().type == TokenType::TK_HAVING) { + tok_.skip(); + AstNode* having = parse_having(); + if (having) root->add_child(having); + } + + // In compound_mode, stop before ORDER BY / LIMIT so the compound + // query parser can claim them as applying to the compound result. 
+ if (!compound_mode_) { + // ORDER BY clause + if (tok_.peek().type == TokenType::TK_ORDER) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + AstNode* order_by = parse_order_by(); + if (order_by) root->add_child(order_by); + } + + // LIMIT clause + if (tok_.peek().type == TokenType::TK_LIMIT) { + tok_.skip(); + AstNode* limit = parse_limit(); + if (limit) root->add_child(limit); + } + + // FOR UPDATE / FOR SHARE (locking) + if (tok_.peek().type == TokenType::TK_FOR) { + AstNode* lock = parse_locking(); + if (lock) root->add_child(lock); + } + + // INTO (MySQL: can appear here too -- INTO OUTFILE/DUMPFILE/var) + if constexpr (D == Dialect::MySQL) { + if (tok_.peek().type == TokenType::TK_INTO) { + AstNode* into = parse_into(); + if (into) root->add_child(into); + } + } + } + + return root; + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser table_ref_parser_; + bool compound_mode_; + + // ---- SELECT options ---- + + AstNode* parse_select_options() { + AstNode* opts = nullptr; + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_DISTINCT || t.type == TokenType::TK_ALL) { + if (!opts) opts = make_node(arena_, NodeType::NODE_SELECT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else if (t.type == TokenType::TK_SQL_CALC_FOUND_ROWS) { + if (!opts) opts = make_node(arena_, NodeType::NODE_SELECT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else { + break; + } + } + return opts; + } + + // ---- Select item list ---- + + AstNode* parse_select_item_list() { + AstNode* list = make_node(arena_, NodeType::NODE_SELECT_ITEM_LIST); + if (!list) return nullptr; + + while (true) { + AstNode* item = parse_select_item(); + if (!item) break; + list->add_child(item); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return list; + } + + 
AstNode* parse_select_item() { + AstNode* item = make_node(arena_, NodeType::NODE_SELECT_ITEM); + if (!item) return nullptr; + + AstNode* expr = expr_parser_.parse(); + if (!expr) return nullptr; + item->add_child(expr); + + // Optional alias: AS name, or just name (implicit alias) + Token next = tok_.peek(); + if (next.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + AstNode* alias = make_node(arena_, NodeType::NODE_ALIAS, alias_name.text); + item->add_child(alias); + } else if (TableRefParser::is_alias_start(next.type)) { + // Implicit alias (no AS keyword): SELECT expr alias_name + tok_.skip(); + AstNode* alias = make_node(arena_, NodeType::NODE_ALIAS, next.text); + item->add_child(alias); + } + return item; + } + + // ---- WHERE ---- + + AstNode* parse_where_clause() { + AstNode* where = make_node(arena_, NodeType::NODE_WHERE_CLAUSE); + if (!where) return nullptr; + AstNode* expr = expr_parser_.parse(); + if (expr) where->add_child(expr); + return where; + } + + // ---- GROUP BY ---- + + AstNode* parse_group_by() { + AstNode* group_by = make_node(arena_, NodeType::NODE_GROUP_BY_CLAUSE); + if (!group_by) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + group_by->add_child(expr); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return group_by; + } + + // ---- HAVING ---- + + AstNode* parse_having() { + AstNode* having = make_node(arena_, NodeType::NODE_HAVING_CLAUSE); + if (!having) return nullptr; + AstNode* expr = expr_parser_.parse(); + if (expr) having->add_child(expr); + return having; + } + + // ---- ORDER BY ---- + + AstNode* parse_order_by() { + AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE); + if (!order_by) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + + AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM); + item->add_child(expr); + + // 
Optional ASC/DESC + Token dir = tok_.peek(); + if (dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) { + tok_.skip(); + item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text)); + } + + order_by->add_child(item); + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return order_by; + } + + // ---- LIMIT ---- + + AstNode* parse_limit() { + AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE); + if (!limit) return nullptr; + + // LIMIT count [OFFSET offset] or LIMIT offset, count (MySQL) + AstNode* first = expr_parser_.parse(); + if (first) limit->add_child(first); + + if (tok_.peek().type == TokenType::TK_OFFSET) { + tok_.skip(); + AstNode* offset = expr_parser_.parse(); + if (offset) limit->add_child(offset); + } else if (tok_.peek().type == TokenType::TK_COMMA) { + // MySQL: LIMIT offset, count + tok_.skip(); + AstNode* count = expr_parser_.parse(); + if (count) limit->add_child(count); + } + + if constexpr (D == Dialect::PostgreSQL) { + // PostgreSQL also supports FETCH FIRST N ROWS ONLY after LIMIT/OFFSET + // We handle OFFSET here too since PgSQL uses LIMIT x OFFSET y + } + + return limit; + } + + // ---- FOR UPDATE / FOR SHARE ---- + + AstNode* parse_locking() { + AstNode* lock = make_node(arena_, NodeType::NODE_LOCKING_CLAUSE); + if (!lock) return nullptr; + + tok_.skip(); // consume FOR + Token strength = tok_.next_token(); // UPDATE or SHARE + lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, strength.text)); + + // Optional: OF table_list + if (tok_.peek().type == TokenType::TK_OF) { + tok_.skip(); + while (true) { + Token table = tok_.next_token(); + lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, table.text)); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + } + + // Optional: NOWAIT or SKIP LOCKED + if (tok_.peek().type == TokenType::TK_NOWAIT) { + tok_.skip(); + lock->add_child(make_node(arena_, 
NodeType::NODE_IDENTIFIER, StringRef{"NOWAIT", 6})); + } else if (tok_.peek().type == TokenType::TK_SKIP) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_LOCKED) tok_.skip(); + lock->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, StringRef{"SKIP LOCKED", 11})); + } + + return lock; + } + + // ---- INTO (MySQL: INTO OUTFILE/DUMPFILE/@var) ---- + + AstNode* parse_into() { + AstNode* into = make_node(arena_, NodeType::NODE_INTO_CLAUSE); + if (!into) return nullptr; + + tok_.skip(); // consume INTO + Token t = tok_.peek(); + + if (t.type == TokenType::TK_OUTFILE) { + tok_.skip(); + Token filename = tok_.next_token(); + into->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, + StringRef{"OUTFILE", 7})); + into->add_child(make_node(arena_, NodeType::NODE_LITERAL_STRING, filename.text)); + } else if (t.type == TokenType::TK_DUMPFILE) { + tok_.skip(); + Token filename = tok_.next_token(); + into->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, + StringRef{"DUMPFILE", 8})); + into->add_child(make_node(arena_, NodeType::NODE_LITERAL_STRING, filename.text)); + } else { + // INTO @var1, @var2, ... 
+ while (true) { + AstNode* var = expr_parser_.parse(); + if (var) into->add_child(var); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + } + + return into; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_SELECT_PARSER_H diff --git a/include/sql_parser/set_parser.h b/include/sql_parser/set_parser.h new file mode 100644 index 0000000..6c909e9 --- /dev/null +++ b/include/sql_parser/set_parser.h @@ -0,0 +1,245 @@ +#ifndef SQL_PARSER_SET_PARSER_H +#define SQL_PARSER_SET_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +template +class SetParser { +public: + SetParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena) {} + + // Parse a SET statement (SET keyword already consumed by classifier). + // Returns the root NODE_SET_STMT node, or nullptr on failure. + AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_SET_STMT); + if (!root) return nullptr; + + Token next = tok_.peek(); + + // SET NAMES ... + if (next.type == TokenType::TK_NAMES) { + tok_.skip(); + AstNode* names_node = parse_set_names(); + if (names_node) root->add_child(names_node); + return root; + } + + // SET CHARACTER SET ... or SET CHARSET ... + if (next.type == TokenType::TK_CHARACTER) { + tok_.skip(); + // Expect SET keyword + if (tok_.peek().type == TokenType::TK_SET) { + tok_.skip(); + } + AstNode* charset_node = parse_set_charset(); + if (charset_node) root->add_child(charset_node); + return root; + } + if (next.type == TokenType::TK_CHARSET) { + tok_.skip(); + AstNode* charset_node = parse_set_charset(); + if (charset_node) root->add_child(charset_node); + return root; + } + + // SET [GLOBAL|SESSION] TRANSACTION ... 
+ // Need to check for scope + TRANSACTION or just TRANSACTION + if (next.type == TokenType::TK_TRANSACTION) { + tok_.skip(); + AstNode* txn_node = parse_set_transaction(StringRef{}); + if (txn_node) root->add_child(txn_node); + return root; + } + + if (next.type == TokenType::TK_GLOBAL || next.type == TokenType::TK_SESSION) { + Token scope_tok = tok_.next_token(); + if (tok_.peek().type == TokenType::TK_TRANSACTION) { + tok_.skip(); + AstNode* txn_node = parse_set_transaction(scope_tok.text); + if (txn_node) root->add_child(txn_node); + return root; + } + // Not TRANSACTION — it's SET GLOBAL var = expr + // Fall through to variable assignment with scope + AstNode* assignment = parse_variable_assignment(&scope_tok); + if (assignment) root->add_child(assignment); + // Parse remaining comma-separated assignments + while (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + AstNode* next_assign = parse_variable_assignment(nullptr); + if (next_assign) root->add_child(next_assign); + } + return root; + } + + // PostgreSQL: SET LOCAL var = expr + if constexpr (D == Dialect::PostgreSQL) { + if (next.type == TokenType::TK_LOCAL) { + Token scope_tok = tok_.next_token(); + AstNode* assignment = parse_variable_assignment(&scope_tok); + if (assignment) root->add_child(assignment); + return root; + } + } + + // SET var = expr [, var = expr, ...] 
+ AstNode* assignment = parse_variable_assignment(nullptr); + if (assignment) root->add_child(assignment); + while (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + AstNode* next_assign = parse_variable_assignment(nullptr); + if (next_assign) root->add_child(next_assign); + } + + return root; + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + + // SET NAMES charset [COLLATE collation] + AstNode* parse_set_names() { + AstNode* node = make_node(arena_, NodeType::NODE_SET_NAMES); + if (!node) return nullptr; + + // charset name or DEFAULT + Token charset = tok_.next_token(); + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, charset.text)); + + // Optional COLLATE + if (tok_.peek().type == TokenType::TK_COLLATE) { + tok_.skip(); + Token collation = tok_.next_token(); + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, collation.text)); + } + return node; + } + + // SET CHARACTER SET charset / SET CHARSET charset + AstNode* parse_set_charset() { + AstNode* node = make_node(arena_, NodeType::NODE_SET_CHARSET); + if (!node) return nullptr; + + Token charset = tok_.next_token(); + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, charset.text)); + return node; + } + + // SET [GLOBAL|SESSION] TRANSACTION ... + AstNode* parse_set_transaction(StringRef scope) { + AstNode* node = make_node(arena_, NodeType::NODE_SET_TRANSACTION); + if (!node) return nullptr; + + if (!scope.empty()) { + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, scope)); + } + + // ISOLATION LEVEL ... 
or READ ONLY/WRITE + Token next = tok_.peek(); + if (next.type == TokenType::TK_ISOLATION) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_LEVEL) tok_.skip(); + + // READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE + Token level = tok_.next_token(); + if (level.type == TokenType::TK_READ) { + Token sublevel = tok_.next_token(); + // Combine "READ COMMITTED" or "READ UNCOMMITTED" + StringRef combined{level.text.ptr, + static_cast((sublevel.text.ptr + sublevel.text.len) - level.text.ptr)}; + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined)); + } else if (level.type == TokenType::TK_REPEATABLE) { + Token read_tok = tok_.next_token(); // READ + StringRef combined{level.text.ptr, + static_cast((read_tok.text.ptr + read_tok.text.len) - level.text.ptr)}; + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined)); + } else { + // SERIALIZABLE + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, level.text)); + } + } else if (next.type == TokenType::TK_READ) { + tok_.skip(); + Token rw = tok_.next_token(); // ONLY or WRITE + StringRef combined{next.text.ptr, + static_cast((rw.text.ptr + rw.text.len) - next.text.ptr)}; + node->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, combined)); + } + + return node; + } + + // Parse a single variable assignment: [scope] target = expr + // scope_token is non-null if GLOBAL/SESSION/LOCAL was already consumed + AstNode* parse_variable_assignment(const Token* scope_token) { + AstNode* assignment = make_node(arena_, NodeType::NODE_VAR_ASSIGNMENT); + if (!assignment) return nullptr; + + // Build the variable target + AstNode* target = make_node(arena_, NodeType::NODE_VAR_TARGET); + if (!target) return nullptr; + + if (scope_token) { + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, scope_token->text)); + } + + Token var = tok_.peek(); + if (var.type == TokenType::TK_AT) { + // User variable @name + tok_.skip(); + Token name = tok_.next_token(); + 
StringRef full{var.text.ptr, + static_cast((name.text.ptr + name.text.len) - var.text.ptr)}; + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, full)); + } else if (var.type == TokenType::TK_DOUBLE_AT) { + // System variable @@[scope.]name + tok_.skip(); + Token name = tok_.next_token(); + StringRef full{var.text.ptr, + static_cast((name.text.ptr + name.text.len) - var.text.ptr)}; + // Check for @@scope.name + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token actual_name = tok_.next_token(); + full = StringRef{var.text.ptr, + static_cast((actual_name.text.ptr + actual_name.text.len) - var.text.ptr)}; + } + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, full)); + } else { + // Plain variable name + Token name = tok_.next_token(); + target->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + } + + assignment->add_child(target); + + // Expect = or := (MySQL) or TO (PostgreSQL) + Token eq = tok_.peek(); + if (eq.type == TokenType::TK_EQUAL || eq.type == TokenType::TK_COLON_EQUAL) { + tok_.skip(); + } else if constexpr (D == Dialect::PostgreSQL) { + if (eq.type == TokenType::TK_TO) { + tok_.skip(); + } + } + + // Parse RHS expression + AstNode* rhs = expr_parser_.parse(); + if (rhs) assignment->add_child(rhs); + + return assignment; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_SET_PARSER_H diff --git a/include/sql_parser/stmt_cache.h b/include/sql_parser/stmt_cache.h new file mode 100644 index 0000000..11182bc --- /dev/null +++ b/include/sql_parser/stmt_cache.h @@ -0,0 +1,177 @@ +#ifndef SQL_PARSER_STMT_CACHE_H +#define SQL_PARSER_STMT_CACHE_H + +#include "sql_parser/ast.h" +#include "sql_parser/common.h" +#include "sql_parser/parse_result.h" +#include +#include +#include +#include + +namespace sql_parser { + +// Deep-copy an AST tree from arena to heap memory. +// The returned tree must be freed with free_ast(). 
+inline AstNode* deep_copy_ast(const AstNode* src) { + if (!src) return nullptr; + + AstNode* dst = static_cast(std::malloc(sizeof(AstNode))); + if (!dst) return nullptr; + + dst->type = src->type; + dst->flags = src->flags; + dst->first_child = nullptr; + dst->next_sibling = nullptr; + + // Deep-copy value string to heap + if (src->value_ptr && src->value_len > 0) { + char* val_copy = static_cast(std::malloc(src->value_len)); + if (val_copy) { + std::memcpy(val_copy, src->value_ptr, src->value_len); + } + dst->value_ptr = val_copy; + dst->value_len = src->value_len; + } else { + dst->value_ptr = nullptr; + dst->value_len = 0; + } + + // Recursively copy children + const AstNode* src_child = src->first_child; + AstNode* prev_dst_child = nullptr; + while (src_child) { + AstNode* dst_child = deep_copy_ast(src_child); + if (dst_child) { + if (!dst->first_child) { + dst->first_child = dst_child; + } else if (prev_dst_child) { + prev_dst_child->next_sibling = dst_child; + } + prev_dst_child = dst_child; + } + src_child = src_child->next_sibling; + } + + return dst; +} + +// Free a heap-allocated AST tree (produced by deep_copy_ast). +inline void free_ast(AstNode* node) { + if (!node) return; + // Free children first + AstNode* child = node->first_child; + while (child) { + AstNode* next = child->next_sibling; + free_ast(child); + child = next; + } + // Free value string + if (node->value_ptr) { + std::free(const_cast(node->value_ptr)); + } + std::free(node); +} + +// Cached entry for a prepared statement. 
+struct CachedStmt { + uint32_t stmt_id; + StmtType stmt_type; + AstNode* ast; // heap-allocated deep copy + + ~CachedStmt() { + free_ast(ast); + } + + // Non-copyable + CachedStmt(const CachedStmt&) = delete; + CachedStmt& operator=(const CachedStmt&) = delete; + CachedStmt(CachedStmt&& o) noexcept + : stmt_id(o.stmt_id), stmt_type(o.stmt_type), ast(o.ast) { + o.ast = nullptr; + } + CachedStmt& operator=(CachedStmt&& o) noexcept { + if (this != &o) { + free_ast(ast); + stmt_id = o.stmt_id; + stmt_type = o.stmt_type; + ast = o.ast; + o.ast = nullptr; + } + return *this; + } + + CachedStmt() : stmt_id(0), stmt_type(StmtType::UNKNOWN), ast(nullptr) {} + CachedStmt(uint32_t id, StmtType type, AstNode* a) + : stmt_id(id), stmt_type(type), ast(a) {} +}; + +// Fixed-capacity LRU cache for prepared statements. +class StmtCache { +public: + explicit StmtCache(size_t capacity = 128) : capacity_(capacity) {} + + ~StmtCache() { clear(); } + + // Non-copyable + StmtCache(const StmtCache&) = delete; + StmtCache& operator=(const StmtCache&) = delete; + + // Store a prepared statement. Deep-copies the AST from the arena. + // Evicts LRU entry if at capacity. + bool store(uint32_t stmt_id, StmtType stmt_type, const AstNode* ast) { + // If already exists, remove old entry + evict(stmt_id); + + AstNode* copy = deep_copy_ast(ast); + if (!copy && ast) return false; + + // Evict LRU if at capacity + if (lru_.size() >= capacity_) { + auto& oldest = lru_.back(); + map_.erase(oldest.stmt_id); + lru_.pop_back(); + } + + lru_.emplace_front(stmt_id, stmt_type, copy); + map_[stmt_id] = lru_.begin(); + return true; + } + + // Look up a cached statement. Returns nullptr if not found. + // Moves the entry to front of LRU. + const CachedStmt* lookup(uint32_t stmt_id) { + auto it = map_.find(stmt_id); + if (it == map_.end()) return nullptr; + // Move to front (most recently used) + lru_.splice(lru_.begin(), lru_, it->second); + return &(*it->second); + } + + // Evict a specific statement. 
+ void evict(uint32_t stmt_id) { + auto it = map_.find(stmt_id); + if (it != map_.end()) { + lru_.erase(it->second); + map_.erase(it); + } + } + + // Clear all entries. + void clear() { + lru_.clear(); + map_.clear(); + } + + size_t size() const { return map_.size(); } + size_t capacity() const { return capacity_; } + +private: + size_t capacity_; + std::list lru_; + std::unordered_map::iterator> map_; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_STMT_CACHE_H diff --git a/include/sql_parser/string_builder.h b/include/sql_parser/string_builder.h new file mode 100644 index 0000000..be4311c --- /dev/null +++ b/include/sql_parser/string_builder.h @@ -0,0 +1,83 @@ +#ifndef SQL_PARSER_STRING_BUILDER_H +#define SQL_PARSER_STRING_BUILDER_H + +#include "sql_parser/common.h" +#include "sql_parser/arena.h" +#include + +namespace sql_parser { + +// Arena-backed string builder for emitting SQL. +// Builds a string by appending chunks. The final result is a contiguous +// StringRef obtained via finish(). All memory is arena-allocated. 
+class StringBuilder { +public: + explicit StringBuilder(Arena& arena, size_t initial_capacity = 1024) + : arena_(arena), capacity_(initial_capacity), len_(0) { + buf_ = static_cast(arena_.allocate(capacity_)); + } + + void append(const char* s, size_t n) { + ensure_capacity(n); + if (buf_) { + std::memcpy(buf_ + len_, s, n); + len_ += n; + } + } + + void append(StringRef ref) { + if (ref.ptr && ref.len > 0) { + append(ref.ptr, ref.len); + } + } + + void append(const char* s) { + append(s, std::strlen(s)); + } + + void append_char(char c) { + ensure_capacity(1); + if (buf_) { + buf_[len_++] = c; + } + } + + // Append a space if the last character isn't already a space + void space() { + if (len_ > 0 && buf_[len_ - 1] != ' ') { + append_char(' '); + } + } + + StringRef finish() { + return StringRef{buf_, static_cast(len_)}; + } + + size_t length() const { return len_; } + +private: + Arena& arena_; + char* buf_; + size_t capacity_; + size_t len_; + + void ensure_capacity(size_t additional) { + if (!buf_) return; + if (len_ + additional <= capacity_) return; + + size_t new_cap = capacity_ * 2; + while (new_cap < len_ + additional) new_cap *= 2; + + char* new_buf = static_cast(arena_.allocate(new_cap)); + if (new_buf) { + std::memcpy(new_buf, buf_, len_); + } + buf_ = new_buf; + capacity_ = new_cap; + // Old buffer is abandoned in the arena — freed on arena reset + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_STRING_BUILDER_H diff --git a/include/sql_parser/table_ref_parser.h b/include/sql_parser/table_ref_parser.h new file mode 100644 index 0000000..153d4d6 --- /dev/null +++ b/include/sql_parser/table_ref_parser.h @@ -0,0 +1,241 @@ +#ifndef SQL_PARSER_TABLE_REF_PARSER_H +#define SQL_PARSER_TABLE_REF_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" + +namespace sql_parser { + +template +class 
TableRefParser { +public: + TableRefParser(Tokenizer& tokenizer, Arena& arena, + ExpressionParser& expr_parser) + : tok_(tokenizer), arena_(arena), expr_parser_(expr_parser) {} + + // Parse a FROM clause: table_ref [, table_ref | JOIN ...]* + AstNode* parse_from_clause() { + AstNode* from = make_node(arena_, NodeType::NODE_FROM_CLAUSE); + if (!from) return nullptr; + + // First table reference + AstNode* table_ref = parse_table_reference(); + if (table_ref) from->add_child(table_ref); + + // Additional table refs (comma join) or explicit JOINs + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_COMMA) { + // Comma join: FROM t1, t2 + tok_.skip(); + AstNode* next_ref = parse_table_reference(); + if (next_ref) from->add_child(next_ref); + } else if (is_join_start(t.type)) { + // Explicit JOIN + AstNode* join = parse_join(from->first_child); + if (join) { + from->add_child(join); + } + } else { + break; + } + } + + return from; + } + + // Parse a single table reference (simple name, qualified name, subquery) + AstNode* parse_table_reference() { + Token t = tok_.peek(); + + // Subquery: (SELECT ...) 
+ if (t.type == TokenType::TK_LPAREN) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_SELECT) { + AstNode* subq = make_node(arena_, NodeType::NODE_SUBQUERY); + // Skip to matching paren + int depth = 1; + while (depth > 0) { + Token st = tok_.next_token(); + if (st.type == TokenType::TK_LPAREN) ++depth; + else if (st.type == TokenType::TK_RPAREN) --depth; + else if (st.type == TokenType::TK_EOF) break; + } + // Optional alias + AstNode* ref = make_node(arena_, NodeType::NODE_TABLE_REF); + ref->add_child(subq); + parse_optional_alias(ref); + return ref; + } + // Parenthesized table reference -- parse inner + AstNode* inner = parse_table_reference(); + if (tok_.peek().type == TokenType::TK_RPAREN) tok_.skip(); + return inner; + } + + // Simple table name or schema.table + AstNode* ref = make_node(arena_, NodeType::NODE_TABLE_REF); + Token name = tok_.next_token(); + + if (tok_.peek().type == TokenType::TK_DOT) { + // Qualified: schema.table + tok_.skip(); + Token table_name = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, table_name.text)); + ref->add_child(qname); + } else { + ref->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, name.text)); + } + + // Optional alias + parse_optional_alias(ref); + return ref; + } + + // Parse a JOIN clause + AstNode* parse_join(AstNode* /* left_ref */) { + AstNode* join = make_node(arena_, NodeType::NODE_JOIN_CLAUSE); + if (!join) return nullptr; + + // Consume join type tokens + Token t = tok_.peek(); + StringRef join_type_start = t.text; + StringRef join_type_end = t.text; + + // Optional: NATURAL, LEFT, RIGHT, FULL, INNER, OUTER, CROSS + while (t.type == TokenType::TK_NATURAL || t.type == TokenType::TK_LEFT || + t.type == TokenType::TK_RIGHT || t.type == TokenType::TK_FULL || + t.type == TokenType::TK_INNER || t.type == 
TokenType::TK_OUTER || + t.type == TokenType::TK_CROSS) { + tok_.skip(); + join_type_end = t.text; + t = tok_.peek(); + } + + // Expect JOIN keyword + if (t.type == TokenType::TK_JOIN) { + join_type_end = t.text; + tok_.skip(); + } + + // Set join type as value (covers the span from first modifier to JOIN) + StringRef join_type{join_type_start.ptr, + static_cast((join_type_end.ptr + join_type_end.len) - join_type_start.ptr)}; + join->value_ptr = join_type.ptr; + join->value_len = join_type.len; + + // Right table reference + AstNode* right_ref = parse_table_reference(); + if (right_ref) join->add_child(right_ref); + + // Join condition: ON expr or USING (col_list) + if (tok_.peek().type == TokenType::TK_ON) { + tok_.skip(); + AstNode* on_expr = expr_parser_.parse(); + if (on_expr) join->add_child(on_expr); + } else if (tok_.peek().type == TokenType::TK_USING) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_LPAREN) { + tok_.skip(); + AstNode* using_list = make_node(arena_, NodeType::NODE_IDENTIFIER, StringRef{"USING", 5}); + while (true) { + Token col = tok_.next_token(); + using_list->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + if (tok_.peek().type == TokenType::TK_RPAREN) tok_.skip(); + join->add_child(using_list); + } + } + + return join; + } + + // Parse optional alias (AS name or implicit alias) + void parse_optional_alias(AstNode* parent) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + parent->add_child(make_node(arena_, NodeType::NODE_ALIAS, alias_name.text)); + } else if (is_alias_start(t.type)) { + tok_.skip(); + parent->add_child(make_node(arena_, NodeType::NODE_ALIAS, t.text)); + } + } + + // Check if a token can start a JOIN + static bool is_join_start(TokenType type) { + return type == TokenType::TK_JOIN || type == TokenType::TK_INNER || + type == 
TokenType::TK_LEFT || type == TokenType::TK_RIGHT || + type == TokenType::TK_FULL || type == TokenType::TK_OUTER || + type == TokenType::TK_CROSS || type == TokenType::TK_NATURAL; + } + + // Check if a token can start an implicit alias (identifier-like, not a clause keyword) + static bool is_alias_start(TokenType type) { + if (type == TokenType::TK_IDENTIFIER) return true; + // Some keywords are NOT valid alias starts because they start clauses + switch (type) { + case TokenType::TK_FROM: + case TokenType::TK_WHERE: + case TokenType::TK_GROUP: + case TokenType::TK_HAVING: + case TokenType::TK_ORDER: + case TokenType::TK_LIMIT: + case TokenType::TK_FOR: + case TokenType::TK_INTO: + case TokenType::TK_JOIN: + case TokenType::TK_INNER: + case TokenType::TK_LEFT: + case TokenType::TK_RIGHT: + case TokenType::TK_FULL: + case TokenType::TK_OUTER: + case TokenType::TK_CROSS: + case TokenType::TK_NATURAL: + case TokenType::TK_ON: + case TokenType::TK_USING: + case TokenType::TK_UNION: + case TokenType::TK_INTERSECT: + case TokenType::TK_EXCEPT: + case TokenType::TK_SEMICOLON: + case TokenType::TK_RPAREN: + case TokenType::TK_LPAREN: + case TokenType::TK_EOF: + case TokenType::TK_COMMA: + case TokenType::TK_SET: + case TokenType::TK_LOCK: + case TokenType::TK_UNLOCK: + case TokenType::TK_VALUES: + case TokenType::TK_SELECT: + case TokenType::TK_DEFAULT: + case TokenType::TK_RETURNING: + case TokenType::TK_CONFLICT: + case TokenType::TK_DO: + case TokenType::TK_NOTHING: + case TokenType::TK_DUPLICATE: + return false; + default: + return true; // Keywords not in the blocklist can be implicit aliases + } + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser& expr_parser_; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_TABLE_REF_PARSER_H diff --git a/include/sql_parser/token.h b/include/sql_parser/token.h new file mode 100644 index 0000000..4ea91c9 --- /dev/null +++ b/include/sql_parser/token.h @@ -0,0 +1,92 @@ +#ifndef SQL_PARSER_TOKEN_H +#define 
SQL_PARSER_TOKEN_H + +#include "sql_parser/common.h" +#include + +namespace sql_parser { + +enum class TokenType : uint16_t { + TK_EOF = 0, + TK_ERROR, + TK_IDENTIFIER, + TK_INTEGER, + TK_FLOAT, + TK_STRING, + TK_LPAREN, + TK_RPAREN, + TK_COMMA, + TK_SEMICOLON, + TK_DOT, + TK_ASTERISK, + TK_PLUS, + TK_MINUS, + TK_SLASH, + TK_PERCENT, + TK_EQUAL, + TK_NOT_EQUAL, + TK_LESS, + TK_GREATER, + TK_LESS_EQUAL, + TK_GREATER_EQUAL, + TK_AMPERSAND, + TK_PIPE, + TK_CARET, + TK_TILDE, + TK_EXCLAIM, + TK_COLON, + TK_QUESTION, + TK_AT, + TK_DOUBLE_AT, + TK_HASH, + TK_COLON_EQUAL, + TK_DOUBLE_PIPE, + TK_DOUBLE_COLON, + TK_DOLLAR_NUM, + TK_SELECT, TK_INSERT, TK_UPDATE, TK_DELETE, TK_REPLACE, + TK_FROM, TK_WHERE, TK_SET, TK_INTO, TK_VALUES, TK_AS, TK_ON, TK_USING, + TK_JOIN, TK_INNER, TK_LEFT, TK_RIGHT, TK_FULL, TK_OUTER, TK_CROSS, TK_NATURAL, + TK_ORDER, TK_BY, TK_GROUP, TK_HAVING, TK_LIMIT, TK_OFFSET, TK_FETCH, + TK_ASC, TK_DESC, TK_DISTINCT, TK_ALL, + TK_AND, TK_OR, TK_NOT, TK_IS, TK_NULL, TK_IN, TK_BETWEEN, TK_LIKE, TK_EXISTS, + TK_CASE, TK_WHEN, TK_THEN, TK_ELSE, TK_END, TK_TRUE, TK_FALSE, + TK_NAMES, TK_CHARACTER, TK_CHARSET, TK_COLLATE, TK_GLOBAL, TK_SESSION, TK_LOCAL, + TK_PERSIST, TK_DEFAULT, TK_TRANSACTION, TK_ISOLATION, TK_LEVEL, + TK_READ, TK_WRITE, TK_ONLY, TK_COMMITTED, TK_UNCOMMITTED, TK_REPEATABLE, + TK_SERIALIZABLE, TK_TO, + TK_CREATE, TK_ALTER, TK_DROP, TK_TRUNCATE, TK_TABLE, TK_INDEX, TK_VIEW, + TK_DATABASE, TK_SCHEMA, TK_IF, + TK_BEGIN, TK_START, TK_COMMIT, TK_ROLLBACK, TK_SAVEPOINT, + TK_USE, TK_SHOW, TK_PREPARE, TK_EXECUTE, TK_DEALLOCATE, + TK_GRANT, TK_REVOKE, TK_LOCK, TK_UNLOCK, TK_LOAD, TK_DATA, + TK_FOR, TK_SHARE, TK_NOWAIT, TK_SKIP, TK_LOCKED, + TK_OUTFILE, TK_DUMPFILE, TK_IGNORE, TK_LOW_PRIORITY, TK_QUICK, TK_RESET, + TK_UNION, TK_OF, + TK_SQL_CALC_FOUND_ROWS, + TK_COUNT, TK_SUM, TK_AVG, TK_MIN, TK_MAX, + + // INSERT/REPLACE related tokens + TK_DELAYED, + TK_HIGH_PRIORITY, + TK_DUPLICATE, + TK_KEY, + TK_CONFLICT, + TK_DO, + TK_NOTHING, + TK_RETURNING, + 
TK_CONSTRAINT, + + // Compound query operators + TK_INTERSECT, + TK_EXCEPT, +}; + +struct Token { + TokenType type = TokenType::TK_EOF; + StringRef text; + uint32_t offset = 0; +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_TOKEN_H diff --git a/include/sql_parser/tokenizer.h b/include/sql_parser/tokenizer.h new file mode 100644 index 0000000..6106d89 --- /dev/null +++ b/include/sql_parser/tokenizer.h @@ -0,0 +1,369 @@ +#ifndef SQL_PARSER_TOKENIZER_H +#define SQL_PARSER_TOKENIZER_H + +#include "sql_parser/token.h" +#include "sql_parser/keywords_mysql.h" +#include "sql_parser/keywords_pgsql.h" + +namespace sql_parser { + +template +class Tokenizer { +public: + void reset(const char* input, size_t len) { + start_ = input; + cursor_ = input; + end_ = input + len; + has_peeked_ = false; + } + + Token next_token() { + if (has_peeked_) { + has_peeked_ = false; + return peeked_; + } + return scan_token(); + } + + Token peek() { + if (!has_peeked_) { + peeked_ = scan_token(); + has_peeked_ = true; + } + return peeked_; + } + + void skip() { + if (has_peeked_) { + has_peeked_ = false; + } else { + scan_token(); + } + } + + // Expose end of input for remaining-input calculation + const char* input_end() const { return end_; } + +private: + const char* start_ = nullptr; + const char* cursor_ = nullptr; + const char* end_ = nullptr; + Token peeked_; + bool has_peeked_ = false; + + uint32_t offset() const { + return static_cast(cursor_ - start_); + } + + char current() const { return (cursor_ < end_) ? *cursor_ : '\0'; } + char advance() { + char c = current(); + if (cursor_ < end_) ++cursor_; + return c; + } + char peek_char(size_t ahead = 0) const { + const char* p = cursor_ + ahead; + return (p < end_) ? 
*p : '\0'; + } + + void skip_whitespace_and_comments() { + while (cursor_ < end_) { + char c = *cursor_; + + // Whitespace + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + ++cursor_; + continue; + } + + // -- line comment (MySQL requires space after --, PgSQL doesn't but we handle both) + if (c == '-' && peek_char(1) == '-') { + cursor_ += 2; + while (cursor_ < end_ && *cursor_ != '\n') ++cursor_; + continue; + } + + // # line comment (MySQL only) + if constexpr (D == Dialect::MySQL) { + if (c == '#') { + ++cursor_; + while (cursor_ < end_ && *cursor_ != '\n') ++cursor_; + continue; + } + } + + // /* block comment */ + if (c == '/' && peek_char(1) == '*') { + cursor_ += 2; + if constexpr (D == Dialect::PostgreSQL) { + // PostgreSQL supports nested block comments + int depth = 1; + while (cursor_ < end_ && depth > 0) { + if (*cursor_ == '/' && peek_char(1) == '*') { + ++depth; + cursor_ += 2; + } else if (*cursor_ == '*' && peek_char(1) == '/') { + --depth; + cursor_ += 2; + } else { + ++cursor_; + } + } + } else { + // MySQL: no nesting + while (cursor_ < end_) { + if (*cursor_ == '*' && peek_char(1) == '/') { + cursor_ += 2; + break; + } + ++cursor_; + } + } + continue; + } + + break; // not whitespace or comment + } + } + + Token make_token(TokenType type, const char* start, uint32_t len) { + return Token{type, StringRef{start, len}, + static_cast(start - start_)}; + } + + Token scan_identifier_or_keyword() { + const char* start = cursor_; + while (cursor_ < end_) { + char c = *cursor_; + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_') { + ++cursor_; + } else { + break; + } + } + uint32_t len = static_cast(cursor_ - start); + + // Keyword lookup + TokenType kw; + if constexpr (D == Dialect::MySQL) { + kw = mysql_keywords::lookup(start, len); + } else { + kw = pgsql_keywords::lookup(start, len); + } + return make_token(kw, start, len); + } + + Token scan_number() { + const char* start = cursor_; + bool 
has_dot = false; + while (cursor_ < end_) { + char c = *cursor_; + if (c >= '0' && c <= '9') { + ++cursor_; + } else if (c == '.' && !has_dot) { + has_dot = true; + ++cursor_; + } else { + break; + } + } + uint32_t len = static_cast(cursor_ - start); + return make_token(has_dot ? TokenType::TK_FLOAT : TokenType::TK_INTEGER, + start, len); + } + + Token scan_single_quoted_string() { + ++cursor_; // skip opening quote + const char* content_start = cursor_; + while (cursor_ < end_) { + if (*cursor_ == '\'') { + // Check for doubled single-quote escape ('') + if (cursor_ + 1 < end_ && *(cursor_ + 1) == '\'') { + cursor_ += 2; // skip both quotes + continue; + } + break; // end of string + } + if (*cursor_ == '\\') { + ++cursor_; // skip escaped char + if (cursor_ < end_) ++cursor_; + } else { + ++cursor_; + } + } + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; // skip closing quote + return make_token(TokenType::TK_STRING, content_start, len); + } + + // MySQL: backtick-quoted identifier + Token scan_backtick_identifier() { + ++cursor_; // skip opening backtick + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '`') ++cursor_; + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; // skip closing backtick + return make_token(TokenType::TK_IDENTIFIER, content_start, len); + } + + // PostgreSQL: double-quoted identifier + Token scan_double_quoted_identifier() { + ++cursor_; // skip opening quote + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '"') ++cursor_; + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; // skip closing quote + return make_token(TokenType::TK_IDENTIFIER, content_start, len); + } + + // PostgreSQL: $$...$$ dollar-quoted string + Token scan_dollar_string() { + // We're at the first $. 
Simple form: $$content$$ + cursor_ += 2; // skip opening $$ + const char* content_start = cursor_; + while (cursor_ < end_) { + if (*cursor_ == '$' && peek_char(1) == '$') { + uint32_t len = static_cast(cursor_ - content_start); + cursor_ += 2; // skip closing $$ + return make_token(TokenType::TK_STRING, content_start, len); + } + ++cursor_; + } + // Unterminated — return what we have + uint32_t len = static_cast(cursor_ - content_start); + return make_token(TokenType::TK_STRING, content_start, len); + } + + Token scan_token() { + skip_whitespace_and_comments(); + + if (cursor_ >= end_) { + return make_token(TokenType::TK_EOF, cursor_, 0); + } + + char c = *cursor_; + + // Identifiers and keywords + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { + return scan_identifier_or_keyword(); + } + + // Numbers + if (c >= '0' && c <= '9') { + return scan_number(); + } + + // Dot — could be start of .123 float or just dot + if (c == '.' && cursor_ + 1 < end_ && + peek_char(1) >= '0' && peek_char(1) <= '9') { + return scan_number(); + } + + // String literals + if (c == '\'') return scan_single_quoted_string(); + + // MySQL: double-quoted strings; PostgreSQL: double-quoted identifiers + if (c == '"') { + if constexpr (D == Dialect::MySQL) { + // In MySQL, double quotes are strings (unless ANSI_QUOTES mode) + ++cursor_; + const char* content_start = cursor_; + while (cursor_ < end_ && *cursor_ != '"') { + if (*cursor_ == '\\') { ++cursor_; if (cursor_ < end_) ++cursor_; } + else ++cursor_; + } + uint32_t len = static_cast(cursor_ - content_start); + if (cursor_ < end_) ++cursor_; + return make_token(TokenType::TK_STRING, content_start, len); + } else { + return scan_double_quoted_identifier(); + } + } + + // Backtick identifier (MySQL only) + if constexpr (D == Dialect::MySQL) { + if (c == '`') return scan_backtick_identifier(); + } + + // @ and @@ + if (c == '@') { + if (peek_char(1) == '@') { + const char* s = cursor_; + cursor_ += 2; + return 
make_token(TokenType::TK_DOUBLE_AT, s, 2); + } + const char* s = cursor_; + ++cursor_; + return make_token(TokenType::TK_AT, s, 1); + } + + // $ — PostgreSQL: $N placeholder or $$string$$ + if constexpr (D == Dialect::PostgreSQL) { + if (c == '$') { + if (peek_char(1) == '$') { + return scan_dollar_string(); + } + if (peek_char(1) >= '0' && peek_char(1) <= '9') { + const char* start = cursor_; + ++cursor_; // skip $ + while (cursor_ < end_ && *cursor_ >= '0' && *cursor_ <= '9') + ++cursor_; + uint32_t len = static_cast(cursor_ - start); + return make_token(TokenType::TK_DOLLAR_NUM, start, len); + } + } + } + + // Two-character operators + if (cursor_ + 1 < end_) { + char c2 = peek_char(1); + + if (c == '<' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_LESS_EQUAL, s, 2); } + if (c == '>' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_GREATER_EQUAL, s, 2); } + if (c == '!' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_NOT_EQUAL, s, 2); } + if (c == '<' && c2 == '>') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_NOT_EQUAL, s, 2); } + if (c == '|' && c2 == '|') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_DOUBLE_PIPE, s, 2); } + + if constexpr (D == Dialect::MySQL) { + if (c == ':' && c2 == '=') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_COLON_EQUAL, s, 2); } + } + + if constexpr (D == Dialect::PostgreSQL) { + if (c == ':' && c2 == ':') { auto s = cursor_; cursor_ += 2; return make_token(TokenType::TK_DOUBLE_COLON, s, 2); } + } + } + + // Single-character operators/punctuation + const char* s = cursor_; + ++cursor_; + switch (c) { + case '(': return make_token(TokenType::TK_LPAREN, s, 1); + case ')': return make_token(TokenType::TK_RPAREN, s, 1); + case ',': return make_token(TokenType::TK_COMMA, s, 1); + case ';': return make_token(TokenType::TK_SEMICOLON, s, 1); + case '.': return 
make_token(TokenType::TK_DOT, s, 1); + case '*': return make_token(TokenType::TK_ASTERISK, s, 1); + case '+': return make_token(TokenType::TK_PLUS, s, 1); + case '-': return make_token(TokenType::TK_MINUS, s, 1); + case '/': return make_token(TokenType::TK_SLASH, s, 1); + case '%': return make_token(TokenType::TK_PERCENT, s, 1); + case '=': return make_token(TokenType::TK_EQUAL, s, 1); + case '<': return make_token(TokenType::TK_LESS, s, 1); + case '>': return make_token(TokenType::TK_GREATER, s, 1); + case '&': return make_token(TokenType::TK_AMPERSAND, s, 1); + case '|': return make_token(TokenType::TK_PIPE, s, 1); + case '^': return make_token(TokenType::TK_CARET, s, 1); + case '~': return make_token(TokenType::TK_TILDE, s, 1); + case '!': return make_token(TokenType::TK_EXCLAIM, s, 1); + case ':': return make_token(TokenType::TK_COLON, s, 1); + case '?': return make_token(TokenType::TK_QUESTION, s, 1); + case '#': return make_token(TokenType::TK_HASH, s, 1); + default: return make_token(TokenType::TK_ERROR, s, 1); + } + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_TOKENIZER_H diff --git a/include/sql_parser/update_parser.h b/include/sql_parser/update_parser.h new file mode 100644 index 0000000..86310f5 --- /dev/null +++ b/include/sql_parser/update_parser.h @@ -0,0 +1,296 @@ +#ifndef SQL_PARSER_UPDATE_PARSER_H +#define SQL_PARSER_UPDATE_PARSER_H + +#include "sql_parser/common.h" +#include "sql_parser/token.h" +#include "sql_parser/tokenizer.h" +#include "sql_parser/ast.h" +#include "sql_parser/arena.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/table_ref_parser.h" + +namespace sql_parser { + +template +class UpdateParser { +public: + UpdateParser(Tokenizer& tokenizer, Arena& arena) + : tok_(tokenizer), arena_(arena), expr_parser_(tokenizer, arena), + table_ref_parser_(tokenizer, arena, expr_parser_) {} + + // Parse UPDATE statement (UPDATE keyword already consumed). 
+ AstNode* parse() { + AstNode* root = make_node(arena_, NodeType::NODE_UPDATE_STMT); + if (!root) return nullptr; + + if constexpr (D == Dialect::MySQL) { + return parse_mysql(root); + } else { + return parse_pgsql(root); + } + } + +private: + Tokenizer& tok_; + Arena& arena_; + ExpressionParser expr_parser_; + TableRefParser table_ref_parser_; + + // ---- MySQL UPDATE ---- + // UPDATE [LOW_PRIORITY] [IGNORE] table_references SET col=expr [,...] [WHERE] [ORDER BY] [LIMIT] + + AstNode* parse_mysql(AstNode* root) { + // Options: LOW_PRIORITY, IGNORE + AstNode* opts = parse_stmt_options(); + if (opts) root->add_child(opts); + + // Table references (supports JOINs for multi-table UPDATE) + // Use parse_from_clause which handles comma-joins and explicit JOINs + AstNode* from = table_ref_parser_.parse_from_clause(); + if (from) { + // For single-table UPDATE, hoist the single TABLE_REF as direct child + // For multi-table, keep the FROM_CLAUSE + int ref_count = 0; + bool has_join = false; + for (const AstNode* c = from->first_child; c; c = c->next_sibling) { + if (c->type == NodeType::NODE_JOIN_CLAUSE) has_join = true; + ++ref_count; + } + if (ref_count == 1 && !has_join) { + // Single table -- add TABLE_REF directly + root->add_child(from->first_child); + } else { + // Multi-table -- keep FROM_CLAUSE + root->add_child(from); + } + } + + // SET keyword + if (tok_.peek().type == TokenType::TK_SET) { + tok_.skip(); + AstNode* set_clause = parse_update_set_clause(); + if (set_clause) root->add_child(set_clause); + } + + // WHERE + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + + // ORDER BY (single-table only) + if (tok_.peek().type == TokenType::TK_ORDER) { + tok_.skip(); + if (tok_.peek().type == TokenType::TK_BY) tok_.skip(); + AstNode* order_by = parse_order_by(); + if (order_by) root->add_child(order_by); + } + + // LIMIT (single-table only) + if (tok_.peek().type == 
TokenType::TK_LIMIT) { + tok_.skip(); + AstNode* limit = parse_limit(); + if (limit) root->add_child(limit); + } + + return root; + } + + // ---- PostgreSQL UPDATE ---- + // UPDATE [ONLY] table [[AS] alias] SET col=expr [,...] [FROM from_list] [WHERE] [RETURNING] + + AstNode* parse_pgsql(AstNode* root) { + // Optional ONLY keyword + if (tok_.peek().type == TokenType::TK_ONLY) { + AstNode* opts = make_node(arena_, NodeType::NODE_STMT_OPTIONS); + Token only_tok = tok_.next_token(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, only_tok.text)); + root->add_child(opts); + } + + // Single table reference with optional alias + AstNode* table_ref = table_ref_parser_.parse_table_reference(); + if (table_ref) root->add_child(table_ref); + + // SET keyword + if (tok_.peek().type == TokenType::TK_SET) { + tok_.skip(); + AstNode* set_clause = parse_update_set_clause(); + if (set_clause) root->add_child(set_clause); + } + + // FROM clause (PostgreSQL: additional table sources) + if (tok_.peek().type == TokenType::TK_FROM) { + tok_.skip(); + AstNode* from = table_ref_parser_.parse_from_clause(); + if (from) root->add_child(from); + } + + // WHERE + if (tok_.peek().type == TokenType::TK_WHERE) { + tok_.skip(); + AstNode* where = parse_where_clause(); + if (where) root->add_child(where); + } + + // RETURNING + if (tok_.peek().type == TokenType::TK_RETURNING) { + AstNode* ret = parse_returning(); + if (ret) root->add_child(ret); + } + + return root; + } + + // ---- Shared helpers ---- + + // Parse MySQL options: LOW_PRIORITY, IGNORE + AstNode* parse_stmt_options() { + AstNode* opts = nullptr; + while (true) { + Token t = tok_.peek(); + if (t.type == TokenType::TK_LOW_PRIORITY || + t.type == TokenType::TK_IGNORE) { + if (!opts) opts = make_node(arena_, NodeType::NODE_STMT_OPTIONS); + tok_.skip(); + opts->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, t.text)); + } else { + break; + } + } + return opts; + } + + // Parse SET clause: col=expr [, col=expr ...] 
+ AstNode* parse_update_set_clause() { + AstNode* set_clause = make_node(arena_, NodeType::NODE_UPDATE_SET_CLAUSE); + if (!set_clause) return nullptr; + + while (true) { + AstNode* item = parse_set_item(); + if (item) set_clause->add_child(item); + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return set_clause; + } + + // Parse a single col=expr pair + AstNode* parse_set_item() { + AstNode* item = make_node(arena_, NodeType::NODE_UPDATE_SET_ITEM); + if (!item) return nullptr; + + // Column name (may be qualified: table.col) + Token col = tok_.next_token(); + if (tok_.peek().type == TokenType::TK_DOT) { + tok_.skip(); + Token actual_col = tok_.next_token(); + AstNode* qname = make_node(arena_, NodeType::NODE_QUALIFIED_NAME); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, col.text)); + qname->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, actual_col.text)); + item->add_child(qname); + } else { + item->add_child(make_node(arena_, NodeType::NODE_COLUMN_REF, col.text)); + } + + // = sign + if (tok_.peek().type == TokenType::TK_EQUAL) { + tok_.skip(); + } + + // Expression value + AstNode* val = expr_parser_.parse(); + if (val) item->add_child(val); + + return item; + } + + // Parse WHERE clause + AstNode* parse_where_clause() { + AstNode* where = make_node(arena_, NodeType::NODE_WHERE_CLAUSE); + if (!where) return nullptr; + AstNode* expr = expr_parser_.parse(); + if (expr) where->add_child(expr); + return where; + } + + // Parse ORDER BY clause + AstNode* parse_order_by() { + AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE); + if (!order_by) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + + AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM); + item->add_child(expr); + + // Optional ASC/DESC + Token dir = tok_.peek(); + if (dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) { + tok_.skip(); + 
item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text)); + } + + order_by->add_child(item); + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + return order_by; + } + + // Parse LIMIT clause + AstNode* parse_limit() { + AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE); + if (!limit) return nullptr; + + AstNode* count = expr_parser_.parse(); + if (count) limit->add_child(count); + + return limit; + } + + // Parse PostgreSQL RETURNING clause + AstNode* parse_returning() { + if (tok_.peek().type != TokenType::TK_RETURNING) return nullptr; + tok_.skip(); // RETURNING + + AstNode* ret = make_node(arena_, NodeType::NODE_RETURNING_CLAUSE); + if (!ret) return nullptr; + + while (true) { + AstNode* expr = expr_parser_.parse(); + if (!expr) break; + ret->add_child(expr); + + // Check for optional alias + Token next = tok_.peek(); + if (next.type == TokenType::TK_AS) { + tok_.skip(); + Token alias_name = tok_.next_token(); + ret->add_child(make_node(arena_, NodeType::NODE_ALIAS, alias_name.text)); + } + + if (tok_.peek().type == TokenType::TK_COMMA) { + tok_.skip(); + } else { + break; + } + } + + return ret; + } +}; + +} // namespace sql_parser + +#endif // SQL_PARSER_UPDATE_PARSER_H diff --git a/src/sql_parser/arena.cpp b/src/sql_parser/arena.cpp new file mode 100644 index 0000000..17c19c5 --- /dev/null +++ b/src/sql_parser/arena.cpp @@ -0,0 +1,88 @@ +#include "sql_parser/arena.h" +#include + +namespace sql_parser { + +Arena::Block* Arena::allocate_block(size_t capacity) { + void* mem = std::malloc(sizeof(Block) + capacity); + if (!mem) return nullptr; + Block* block = static_cast(mem); + block->next = nullptr; + block->capacity = capacity; + block->used = 0; + return block; +} + +Arena::Arena(size_t block_size, size_t max_size) + : block_size_(block_size), max_size_(max_size), total_allocated_(0) { + primary_ = allocate_block(block_size_); + current_ = primary_; + total_allocated_ = block_size_; +} 
+ +Arena::~Arena() { + Block* b = primary_; + while (b) { + Block* next = b->next; + std::free(b); + b = next; + } +} + +void* Arena::allocate(size_t bytes) { + bytes = (bytes + 7) & ~size_t(7); + + if (current_->used + bytes <= current_->capacity) { + void* ptr = current_->data() + current_->used; + current_->used += bytes; + return ptr; + } + + size_t new_cap = (bytes > block_size_) ? bytes : block_size_; + if (total_allocated_ + new_cap > max_size_) { + return nullptr; + } + + Block* new_block = allocate_block(new_cap); + if (!new_block) return nullptr; + + current_->next = new_block; + current_ = new_block; + total_allocated_ += new_cap; + + void* ptr = current_->data() + current_->used; + current_->used += bytes; + return ptr; +} + +StringRef Arena::allocate_string(const char* src, uint32_t len) { + void* mem = allocate(len); + if (!mem) return StringRef{nullptr, 0}; + std::memcpy(mem, src, len); + return StringRef{static_cast(mem), len}; +} + +void Arena::reset() { + Block* b = primary_->next; + while (b) { + Block* next = b->next; + std::free(b); + b = next; + } + primary_->next = nullptr; + primary_->used = 0; + current_ = primary_; + total_allocated_ = block_size_; +} + +size_t Arena::bytes_used() const { + size_t used = 0; + const Block* b = primary_; + while (b) { + used += b->used; + b = b->next; + } + return used; +} + +} // namespace sql_parser diff --git a/src/sql_parser/parser.cpp b/src/sql_parser/parser.cpp new file mode 100644 index 0000000..20344f6 --- /dev/null +++ b/src/sql_parser/parser.cpp @@ -0,0 +1,738 @@ +#include "sql_parser/parser.h" +#include "sql_parser/expression_parser.h" +#include "sql_parser/set_parser.h" +#include "sql_parser/select_parser.h" +#include "sql_parser/compound_query_parser.h" +#include "sql_parser/insert_parser.h" +#include "sql_parser/update_parser.h" +#include "sql_parser/delete_parser.h" + +namespace sql_parser { + +template +Parser::Parser(const ParserConfig& config) + : arena_(config.arena_block_size, 
config.arena_max_size), + stmt_cache_(config.stmt_cache_capacity) {} + +template +void Parser::reset() { + arena_.reset(); +} + +template +ParseResult Parser::parse(const char* sql, size_t len) { + arena_.reset(); + tokenizer_.reset(sql, len); + return classify_and_dispatch(); +} + +template +ParseResult Parser::classify_and_dispatch() { + Token first = tokenizer_.next_token(); + + if (first.type == TokenType::TK_EOF) { + ParseResult r; + r.status = ParseResult::ERROR; + r.stmt_type = StmtType::UNKNOWN; + return r; + } + + switch (first.type) { + case TokenType::TK_SELECT: return parse_select(); + case TokenType::TK_LPAREN: { + // Parenthesized SELECT / compound query: (SELECT ...) UNION ... + Token next = tokenizer_.peek(); + if (next.type == TokenType::TK_SELECT || next.type == TokenType::TK_LPAREN) { + return parse_select_from_lparen(); + } + return extract_unknown(first); + } + case TokenType::TK_SET: return parse_set(); + case TokenType::TK_INSERT: return parse_insert(false); + case TokenType::TK_UPDATE: return parse_update(); + case TokenType::TK_DELETE: return parse_delete(); + case TokenType::TK_REPLACE: return parse_insert(true); + case TokenType::TK_BEGIN: + case TokenType::TK_START: + case TokenType::TK_COMMIT: + case TokenType::TK_ROLLBACK: + case TokenType::TK_SAVEPOINT:return extract_transaction(first); + case TokenType::TK_USE: return extract_use(first); + case TokenType::TK_SHOW: return extract_show(first); + case TokenType::TK_PREPARE: return extract_prepare(first); + case TokenType::TK_EXECUTE: return extract_execute(first); + case TokenType::TK_DEALLOCATE: return extract_deallocate(first); + case TokenType::TK_CREATE: + case TokenType::TK_ALTER: + case TokenType::TK_DROP: + case TokenType::TK_TRUNCATE: return extract_ddl(first); + case TokenType::TK_GRANT: + case TokenType::TK_REVOKE: return extract_acl(first); + case TokenType::TK_LOCK: + case TokenType::TK_UNLOCK: return extract_lock(first); + case TokenType::TK_LOAD: return 
extract_load(first); + case TokenType::TK_RESET: return extract_reset(first); + default: return extract_unknown(first); + } +} + +// ---- Tier 1 stubs ---- + +template +ParseResult Parser::parse_select() { + ParseResult r; + r.stmt_type = StmtType::SELECT; + + CompoundQueryParser compound_parser(tokenizer_, arena_); + AstNode* ast = compound_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::parse_select_from_lparen() { + // Called when classifier consumed '(' and peeked SELECT or '(' + // We need to parse the inner compound query, then check for set operators + // after the closing ')'. + // + // Strategy: parse inner as a fresh compound expression, expect ')', + // then check if a set operator follows (making this a compound query). + + ParseResult r; + r.stmt_type = StmtType::SELECT; + + // We're inside '(' already consumed. + // Parse inner: could be SELECT or another '(' + AstNode* inner = nullptr; + if (tokenizer_.peek().type == TokenType::TK_SELECT) { + tokenizer_.skip(); // consume SELECT + SelectParser sp(tokenizer_, arena_, true); + inner = sp.parse(); + + // Check for set operators inside the parens + Token t = tokenizer_.peek(); + while (t.type == TokenType::TK_UNION || + t.type == TokenType::TK_INTERSECT || + t.type == TokenType::TK_EXCEPT) { + tokenizer_.skip(); + StringRef op_text = t.text; + uint16_t flags = 0; + if (tokenizer_.peek().type == TokenType::TK_ALL) { + tokenizer_.skip(); + flags = FLAG_SET_OP_ALL; + } + // Next SELECT + if (tokenizer_.peek().type == TokenType::TK_SELECT) { + tokenizer_.skip(); + } + SelectParser sp2(tokenizer_, arena_, true); + AstNode* right = sp2.parse(); + + AstNode* setop = make_node(arena_, NodeType::NODE_SET_OPERATION, op_text); + if (setop) { + setop->flags = flags; + setop->add_child(inner); + if (right) setop->add_child(right); + inner = setop; + } + t = 
tokenizer_.peek(); + } + } else { + // Nested parenthesized -- recursively handle + // This is an edge case; for now parse as compound + CompoundQueryParser cp(tokenizer_, arena_); + inner = cp.parse(); + } + + // Expect closing ')' + if (tokenizer_.peek().type == TokenType::TK_RPAREN) { + tokenizer_.skip(); + } + + // Now check if a set operator follows after the ')' + Token t = tokenizer_.peek(); + if (t.type == TokenType::TK_UNION || + t.type == TokenType::TK_INTERSECT || + t.type == TokenType::TK_EXCEPT) { + // This is a compound query starting with a parenthesized operand. + // Use CompoundQueryParser to continue, but we already have the left operand. + // We'll build the compound manually. + AstNode* left = inner; + while (true) { + t = tokenizer_.peek(); + if (t.type != TokenType::TK_UNION && + t.type != TokenType::TK_INTERSECT && + t.type != TokenType::TK_EXCEPT) break; + + tokenizer_.skip(); + StringRef op_text = t.text; + uint16_t flags = 0; + if (tokenizer_.peek().type == TokenType::TK_ALL) { + tokenizer_.skip(); + flags = FLAG_SET_OP_ALL; + } + + AstNode* right = nullptr; + if (tokenizer_.peek().type == TokenType::TK_LPAREN) { + // Parenthesized right operand + tokenizer_.skip(); + if (tokenizer_.peek().type == TokenType::TK_SELECT) { + tokenizer_.skip(); + } + SelectParser sp3(tokenizer_, arena_, true); + right = sp3.parse(); + if (tokenizer_.peek().type == TokenType::TK_RPAREN) { + tokenizer_.skip(); + } + } else if (tokenizer_.peek().type == TokenType::TK_SELECT) { + tokenizer_.skip(); + SelectParser sp3(tokenizer_, arena_, true); + right = sp3.parse(); + } + + AstNode* setop = make_node(arena_, NodeType::NODE_SET_OPERATION, op_text); + if (setop) { + setop->flags = flags; + setop->add_child(left); + if (right) setop->add_child(right); + left = setop; + } + } + + // Wrap in COMPOUND_QUERY + AstNode* compound = make_node(arena_, NodeType::NODE_COMPOUND_QUERY); + if (compound) { + compound->add_child(left); + + // Trailing ORDER BY + if 
(tokenizer_.peek().type == TokenType::TK_ORDER) { + tokenizer_.skip(); + if (tokenizer_.peek().type == TokenType::TK_BY) tokenizer_.skip(); + ExpressionParser ep(tokenizer_, arena_); + AstNode* order_by = make_node(arena_, NodeType::NODE_ORDER_BY_CLAUSE); + if (order_by) { + while (true) { + AstNode* expr = ep.parse(); + if (!expr) break; + AstNode* item = make_node(arena_, NodeType::NODE_ORDER_BY_ITEM); + item->add_child(expr); + Token dir = tokenizer_.peek(); + if (dir.type == TokenType::TK_ASC || dir.type == TokenType::TK_DESC) { + tokenizer_.skip(); + item->add_child(make_node(arena_, NodeType::NODE_IDENTIFIER, dir.text)); + } + order_by->add_child(item); + if (tokenizer_.peek().type == TokenType::TK_COMMA) { + tokenizer_.skip(); + } else { + break; + } + } + compound->add_child(order_by); + } + } + + // Trailing LIMIT + if (tokenizer_.peek().type == TokenType::TK_LIMIT) { + tokenizer_.skip(); + ExpressionParser ep(tokenizer_, arena_); + AstNode* limit = make_node(arena_, NodeType::NODE_LIMIT_CLAUSE); + if (limit) { + AstNode* val = ep.parse(); + if (val) limit->add_child(val); + if (tokenizer_.peek().type == TokenType::TK_OFFSET) { + tokenizer_.skip(); + AstNode* off = ep.parse(); + if (off) limit->add_child(off); + } + compound->add_child(limit); + } + } + + r.status = ParseResult::OK; + r.ast = compound; + } + } else { + // Just a parenthesized SELECT, no compound + if (inner) { + r.status = ParseResult::OK; + r.ast = inner; + } else { + r.status = ParseResult::PARTIAL; + } + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::parse_set() { + ParseResult r; + r.stmt_type = StmtType::SET; + + SetParser set_parser(tokenizer_, arena_); + AstNode* ast = set_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::parse_insert(bool is_replace) { + ParseResult r; + r.stmt_type = is_replace ? 
StmtType::REPLACE : StmtType::INSERT; + + InsertParser insert_parser(tokenizer_, arena_, is_replace); + AstNode* ast = insert_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + + // Extract table_name/schema_name from AST for backward compatibility + for (const AstNode* child = ast->first_child; child; child = child->next_sibling) { + if (child->type == NodeType::NODE_TABLE_REF) { + const AstNode* name_node = child->first_child; + if (name_node && name_node->type == NodeType::NODE_QUALIFIED_NAME) { + // schema.table + const AstNode* schema = name_node->first_child; + const AstNode* table = schema ? schema->next_sibling : nullptr; + if (schema) r.schema_name = schema->value(); + if (table) r.table_name = table->value(); + } else if (name_node && name_node->type == NodeType::NODE_IDENTIFIER) { + r.table_name = name_node->value(); + } + break; + } + } + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::parse_update() { + ParseResult r; + r.stmt_type = StmtType::UPDATE; + + UpdateParser update_parser(tokenizer_, arena_); + AstNode* ast = update_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + + // Extract table_name/schema_name from AST for backward compatibility + for (const AstNode* child = ast->first_child; child; child = child->next_sibling) { + if (child->type == NodeType::NODE_TABLE_REF) { + const AstNode* name_node = child->first_child; + if (name_node && name_node->type == NodeType::NODE_QUALIFIED_NAME) { + const AstNode* schema = name_node->first_child; + const AstNode* table = schema ? 
schema->next_sibling : nullptr; + if (schema) r.schema_name = schema->value(); + if (table) r.table_name = table->value(); + } else if (name_node && name_node->type == NodeType::NODE_IDENTIFIER) { + r.table_name = name_node->value(); + } + break; + } + } + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::parse_delete() { + ParseResult r; + r.stmt_type = StmtType::DELETE_STMT; + + DeleteParser delete_parser(tokenizer_, arena_); + AstNode* ast = delete_parser.parse(); + + if (ast) { + r.status = ParseResult::OK; + r.ast = ast; + + // Extract table_name/schema_name from AST for backward compatibility + for (const AstNode* child = ast->first_child; child; child = child->next_sibling) { + if (child->type == NodeType::NODE_TABLE_REF) { + const AstNode* name_node = child->first_child; + if (name_node && name_node->type == NodeType::NODE_QUALIFIED_NAME) { + const AstNode* schema = name_node->first_child; + const AstNode* table = schema ? schema->next_sibling : nullptr; + if (schema) r.schema_name = schema->value(); + if (table) r.table_name = table->value(); + } else if (name_node && name_node->type == NodeType::NODE_IDENTIFIER) { + r.table_name = name_node->value(); + } + break; + } + } + } else { + r.status = ParseResult::PARTIAL; + } + + scan_to_end(r); + return r; +} + +// ---- Helpers ---- + +template +Token Parser::read_table_name(StringRef& schema_out) { + Token name = tokenizer_.next_token(); + if (name.type != TokenType::TK_IDENTIFIER && + name.type != TokenType::TK_EOF) { + // Keywords used as table names (e.g., CREATE TABLE `user`) + // The tokenizer returns keyword tokens for reserved words. + // Accept any non-punctuation token as a potential name. 
+ } + + // Check for qualified name: schema.table + if (tokenizer_.peek().type == TokenType::TK_DOT) { + schema_out = name.text; + tokenizer_.skip(); // consume dot + Token table = tokenizer_.next_token(); + return table; + } + + schema_out = StringRef{}; + return name; +} + +template +void Parser::scan_to_end(ParseResult& result) { + while (true) { + Token t = tokenizer_.next_token(); + if (t.type == TokenType::TK_EOF) break; + if (t.type == TokenType::TK_SEMICOLON) { + Token next = tokenizer_.peek(); + if (next.type != TokenType::TK_EOF) { + const char* remaining_start = next.text.ptr; + const char* input_end = tokenizer_.input_end(); + result.remaining = StringRef{ + remaining_start, + static_cast(input_end - remaining_start) + }; + } + break; + } + } +} + +// ---- Tier 2 Extractors ---- + +template +ParseResult Parser::extract_insert(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::INSERT; + + // Expect optional INTO + Token t = tokenizer_.peek(); + if (t.type == TokenType::TK_INTO) { + tokenizer_.skip(); + } + + // Read table name + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_update(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::UPDATE; + + // Optional LOW_PRIORITY / IGNORE + Token t = tokenizer_.peek(); + while (t.type == TokenType::TK_LOW_PRIORITY || t.type == TokenType::TK_IGNORE) { + tokenizer_.skip(); + t = tokenizer_.peek(); + } + + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_delete(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::DELETE_STMT; + + // Optional LOW_PRIORITY / QUICK / IGNORE + Token t = tokenizer_.peek(); + while (t.type == TokenType::TK_LOW_PRIORITY || + t.type == 
TokenType::TK_QUICK || + t.type == TokenType::TK_IGNORE) { + tokenizer_.skip(); + t = tokenizer_.peek(); + } + + // Expect FROM + if (tokenizer_.peek().type == TokenType::TK_FROM) { + tokenizer_.skip(); + } + + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_replace(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::REPLACE; + + if (tokenizer_.peek().type == TokenType::TK_INTO) { + tokenizer_.skip(); + } + + Token table = read_table_name(r.schema_name); + r.table_name = table.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_transaction(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + + switch (first.type) { + case TokenType::TK_BEGIN: + r.stmt_type = StmtType::BEGIN; + break; + case TokenType::TK_START: + r.stmt_type = StmtType::START_TRANSACTION; + // consume TRANSACTION if present + if (tokenizer_.peek().type == TokenType::TK_TRANSACTION) + tokenizer_.skip(); + break; + case TokenType::TK_COMMIT: + r.stmt_type = StmtType::COMMIT; + break; + case TokenType::TK_ROLLBACK: + r.stmt_type = StmtType::ROLLBACK; + break; + case TokenType::TK_SAVEPOINT: + r.stmt_type = StmtType::SAVEPOINT; + break; + default: + r.stmt_type = StmtType::UNKNOWN; + break; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_use(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::USE; + + Token db = tokenizer_.next_token(); + r.database_name = db.text; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_show(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::SHOW; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_prepare(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = 
StmtType::PREPARE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_execute(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::EXECUTE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_deallocate(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::DEALLOCATE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_ddl(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + + switch (first.type) { + case TokenType::TK_CREATE: r.stmt_type = StmtType::CREATE; break; + case TokenType::TK_ALTER: r.stmt_type = StmtType::ALTER; break; + case TokenType::TK_DROP: r.stmt_type = StmtType::DROP; break; + case TokenType::TK_TRUNCATE: r.stmt_type = StmtType::TRUNCATE; break; + default: r.stmt_type = StmtType::UNKNOWN; break; + } + + // Try to extract object name: CREATE/ALTER/DROP [IF EXISTS/NOT EXISTS] TABLE/INDEX/VIEW name + Token t = tokenizer_.next_token(); + + // Skip optional IF [NOT] EXISTS + if (t.type == TokenType::TK_IF) { + t = tokenizer_.next_token(); // NOT or EXISTS + if (t.type == TokenType::TK_NOT) { + t = tokenizer_.next_token(); // EXISTS + } + // Skip EXISTS + t = tokenizer_.next_token(); // should be TABLE/INDEX/etc. 
+ } + + // Now t should be TABLE, INDEX, VIEW, DATABASE, SCHEMA, or a name + if (t.type == TokenType::TK_TABLE || t.type == TokenType::TK_INDEX || + t.type == TokenType::TK_VIEW || t.type == TokenType::TK_DATABASE || + t.type == TokenType::TK_SCHEMA) { + // Optional IF [NOT] EXISTS after object type for CREATE/DROP + Token maybe_if = tokenizer_.peek(); + if (maybe_if.type == TokenType::TK_IF) { + tokenizer_.skip(); // IF + Token next = tokenizer_.next_token(); + if (next.type == TokenType::TK_NOT) { + tokenizer_.skip(); // EXISTS + } + } + Token name = read_table_name(r.schema_name); + r.table_name = name.text; + } + + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_acl(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = (first.type == TokenType::TK_GRANT) ? StmtType::GRANT : StmtType::REVOKE; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_lock(const Token& first) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = (first.type == TokenType::TK_LOCK) ? 
StmtType::LOCK : StmtType::UNLOCK; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_load(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::LOAD_DATA; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_reset(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::RESET; + scan_to_end(r); + return r; +} + +template +ParseResult Parser::extract_unknown(const Token& /* first */) { + ParseResult r; + r.status = ParseResult::OK; + r.stmt_type = StmtType::UNKNOWN; + scan_to_end(r); + return r; +} + +// ---- Prepared statement support ---- + +template +ParseResult Parser::parse_and_cache(const char* sql, size_t len, uint32_t stmt_id) { + ParseResult r = parse(sql, len); + if (r.ast) { + stmt_cache_.store(stmt_id, r.stmt_type, r.ast); + } + return r; +} + +template +ParseResult Parser::execute(uint32_t stmt_id, const ParamBindings& params) { + ParseResult r; + const CachedStmt* cached = stmt_cache_.lookup(stmt_id); + if (!cached) { + r.status = ParseResult::ERROR; + r.stmt_type = StmtType::UNKNOWN; + return r; + } + r.status = ParseResult::OK; + r.stmt_type = cached->stmt_type; + r.ast = cached->ast; + r.bindings = params; + return r; +} + +template +void Parser::prepare_cache_evict(uint32_t stmt_id) { + stmt_cache_.evict(stmt_id); +} + +// ---- Explicit template instantiations ---- + +template class Parser; +template class Parser; + +} // namespace sql_parser diff --git a/tests/test_arena.cpp b/tests/test_arena.cpp new file mode 100644 index 0000000..b3f4d89 --- /dev/null +++ b/tests/test_arena.cpp @@ -0,0 +1,78 @@ +#include +#include "sql_parser/arena.h" + +using namespace sql_parser; + +TEST(ArenaTest, AllocateAndReset) { + Arena arena(4096); + void* p1 = arena.allocate(64); + ASSERT_NE(p1, nullptr); + void* p2 = arena.allocate(64); + ASSERT_NE(p2, nullptr); + EXPECT_NE(p1, p2); + + arena.reset(); + void* p3 = arena.allocate(64); 
+ ASSERT_NE(p3, nullptr); + EXPECT_EQ(p1, p3); +} + +TEST(ArenaTest, AllocateAligned) { + Arena arena(4096); + (void)arena.allocate(1); // advance cursor by 1 byte + void* p2 = arena.allocate(8); + EXPECT_EQ(reinterpret_cast(p2) % 8, 0u); +} + +TEST(ArenaTest, OverflowToNewBlock) { + Arena arena(128); + void* p1 = arena.allocate(100); + ASSERT_NE(p1, nullptr); + void* p2 = arena.allocate(100); + ASSERT_NE(p2, nullptr); + EXPECT_NE(p1, p2); +} + +TEST(ArenaTest, ResetFreesOverflowBlocks) { + Arena arena(128); + arena.allocate(100); + arena.allocate(100); + arena.reset(); + void* p = arena.allocate(64); + ASSERT_NE(p, nullptr); +} + +TEST(ArenaTest, MaxSizeEnforced) { + Arena arena(128, 256); + void* p1 = arena.allocate(100); + ASSERT_NE(p1, nullptr); + void* p2 = arena.allocate(100); + ASSERT_NE(p2, nullptr); + void* p3 = arena.allocate(100); + EXPECT_EQ(p3, nullptr); +} + +TEST(ArenaTest, AllocateTyped) { + Arena arena(4096); + + struct TestStruct { + int a; + double b; + }; + + TestStruct* ts = arena.allocate_typed(); + ASSERT_NE(ts, nullptr); + ts->a = 42; + ts->b = 3.14; + EXPECT_EQ(ts->a, 42); + EXPECT_DOUBLE_EQ(ts->b, 3.14); +} + +TEST(ArenaTest, AllocateString) { + Arena arena(4096); + const char* src = "hello world"; + StringRef ref = arena.allocate_string(src, 11); + EXPECT_EQ(ref.len, 11u); + EXPECT_EQ(std::memcmp(ref.ptr, "hello world", 11), 0); + EXPECT_NE(ref.ptr, src); +} diff --git a/tests/test_classifier.cpp b/tests/test_classifier.cpp new file mode 100644 index 0000000..0462ac3 --- /dev/null +++ b/tests/test_classifier.cpp @@ -0,0 +1,182 @@ +#include +#include "sql_parser/parser.h" + +using namespace sql_parser; + +// ========== MySQL Classifier Tests ========== + +class MySQLClassifierTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(MySQLClassifierTest, ClassifySelect) { + auto r = parser.parse("SELECT * FROM users", 19); + // SELECT is Tier 1 — for now returns PARTIAL until deep parser is implemented + 
EXPECT_EQ(r.stmt_type, StmtType::SELECT); +} + +TEST_F(MySQLClassifierTest, ClassifyInsert) { + auto r = parser.parse("INSERT INTO users VALUES (1, 'a')", 33); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyInsertQualified) { + auto r = parser.parse("INSERT INTO mydb.users VALUES (1)", 33); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.schema_name.ptr, r.schema_name.len), "mydb"); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyUpdate) { + auto r = parser.parse("UPDATE users SET name='x'", 25); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::UPDATE); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyDelete) { + auto r = parser.parse("DELETE FROM users WHERE id=1", 28); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifySet) { + auto r = parser.parse("SET autocommit=0", 16); + EXPECT_EQ(r.stmt_type, StmtType::SET); +} + +TEST_F(MySQLClassifierTest, ClassifyUse) { + auto r = parser.parse("USE mydb", 8); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::USE); + EXPECT_EQ(std::string(r.database_name.ptr, r.database_name.len), "mydb"); +} + +TEST_F(MySQLClassifierTest, ClassifyBegin) { + auto r = parser.parse("BEGIN", 5); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::BEGIN); +} + +TEST_F(MySQLClassifierTest, ClassifyStartTransaction) { + auto r = parser.parse("START TRANSACTION", 17); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::START_TRANSACTION); +} + +TEST_F(MySQLClassifierTest, ClassifyCommit) { + auto r = 
parser.parse("COMMIT", 6); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::COMMIT); +} + +TEST_F(MySQLClassifierTest, ClassifyRollback) { + auto r = parser.parse("ROLLBACK", 8); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::ROLLBACK); +} + +TEST_F(MySQLClassifierTest, ClassifyCreateTable) { + auto r = parser.parse("CREATE TABLE users (id INT)", 27); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::CREATE); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyDropTable) { + auto r = parser.parse("DROP TABLE users", 16); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DROP); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyShow) { + auto r = parser.parse("SHOW TABLES", 11); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SHOW); +} + +TEST_F(MySQLClassifierTest, ClassifyReplace) { + auto r = parser.parse("REPLACE INTO users VALUES (1)", 29); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REPLACE); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(MySQLClassifierTest, ClassifyGrant) { + auto r = parser.parse("GRANT SELECT ON users TO 'app'", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::GRANT); +} + +TEST_F(MySQLClassifierTest, ClassifyRevoke) { + auto r = parser.parse("REVOKE ALL ON users FROM 'app'", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REVOKE); +} + +TEST_F(MySQLClassifierTest, ClassifyLock) { + auto r = parser.parse("LOCK TABLES users WRITE", 23); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::LOCK); +} + +TEST_F(MySQLClassifierTest, ClassifyDeallocate) { + auto r = parser.parse("DEALLOCATE PREPARE stmt1", 24); + EXPECT_EQ(r.status, 
ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DEALLOCATE); +} + +TEST_F(MySQLClassifierTest, ClassifyUnknown) { + auto r = parser.parse("EXPLAIN SELECT 1", 16); + EXPECT_EQ(r.stmt_type, StmtType::UNKNOWN); +} + +TEST_F(MySQLClassifierTest, EmptyInput) { + auto r = parser.parse("", 0); + EXPECT_EQ(r.status, ParseResult::ERROR); + EXPECT_EQ(r.stmt_type, StmtType::UNKNOWN); +} + +TEST_F(MySQLClassifierTest, MultiStatement) { + const char* sql = "BEGIN; SELECT 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.stmt_type, StmtType::BEGIN); + EXPECT_TRUE(r.has_remaining()); + // remaining should point to " SELECT 1" + EXPECT_GT(r.remaining.len, 0u); +} + +TEST_F(MySQLClassifierTest, CaseInsensitive) { + auto r = parser.parse("insert into USERS values (1)", 28); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "USERS"); +} + +// ========== PostgreSQL Classifier Tests ========== + +class PgSQLClassifierTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(PgSQLClassifierTest, ClassifySelect) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); +} + +TEST_F(PgSQLClassifierTest, ClassifyInsert) { + auto r = parser.parse("INSERT INTO users VALUES (1)", 28); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + EXPECT_EQ(std::string(r.table_name.ptr, r.table_name.len), "users"); +} + +TEST_F(PgSQLClassifierTest, ClassifyReset) { + auto r = parser.parse("RESET ALL", 9); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::RESET); +} diff --git a/tests/test_compound.cpp b/tests/test_compound.cpp new file mode 100644 index 0000000..6f122b9 --- /dev/null +++ b/tests/test_compound.cpp @@ -0,0 +1,326 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLCompoundTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + 
int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Simple SELECT (no compound) ========== +// CompoundQueryParser must return bare NODE_SELECT_STMT when no set operator follows + +TEST_F(MySQLCompoundTest, PlainSelectUnchanged) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + // Should be NODE_SELECT_STMT, NOT NODE_COMPOUND_QUERY + EXPECT_EQ(r.ast->type, NodeType::NODE_SELECT_STMT); +} + +// ========== UNION ========== + +TEST_F(MySQLCompoundTest, SimpleUnion) { + const char* sql = "SELECT 1 UNION SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + auto* setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(setop, nullptr); +} + +TEST_F(MySQLCompoundTest, UnionAll) { + const char* sql = "SELECT 1 UNION ALL SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, UnionThreeSelects) { + const char* sql = "SELECT 1 UNION SELECT 2 UNION SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, UnionWithOrderBy) { + 
const char* sql = "SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); +} + +TEST_F(MySQLCompoundTest, UnionWithLimit) { + const char* sql = "SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLCompoundTest, UnionWithOrderByAndLimit) { + const char* sql = "SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a LIMIT 5"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +// ========== INTERSECT ========== + +TEST_F(MySQLCompoundTest, SimpleIntersect) { + const char* sql = "SELECT 1 INTERSECT SELECT 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, IntersectAll) { + const char* sql = "SELECT 1 INTERSECT ALL SELECT 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== EXCEPT ========== + +TEST_F(MySQLCompoundTest, SimpleExcept) { + const char* sql = "SELECT 1 EXCEPT SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, ExceptAll) { + const char* sql = "SELECT 1 EXCEPT ALL SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + 
EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Precedence: INTERSECT > UNION/EXCEPT ========== + +TEST_F(MySQLCompoundTest, IntersectBindsTighterThanUnion) { + // SELECT 1 UNION SELECT 2 INTERSECT SELECT 3 + // Should parse as: SELECT 1 UNION (SELECT 2 INTERSECT SELECT 3) + const char* sql = "SELECT 1 UNION SELECT 2 INTERSECT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + + // The top-level set operation should be UNION + auto* top_setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(top_setop, nullptr); + // The value should contain "UNION" + StringRef op_text = top_setop->value(); + EXPECT_TRUE(op_text.equals_ci("UNION", 5)); + + // The right child of UNION should be a SET_OPERATION (INTERSECT) + const AstNode* left = top_setop->first_child; + ASSERT_NE(left, nullptr); + const AstNode* right = left->next_sibling; + ASSERT_NE(right, nullptr); + EXPECT_EQ(right->type, NodeType::NODE_SET_OPERATION); + StringRef right_op = right->value(); + EXPECT_TRUE(right_op.equals_ci("INTERSECT", 9)); +} + +TEST_F(MySQLCompoundTest, IntersectBindsTighterThanExcept) { + // SELECT 1 EXCEPT SELECT 2 INTERSECT SELECT 3 + // Should parse as: SELECT 1 EXCEPT (SELECT 2 INTERSECT SELECT 3) + const char* sql = "SELECT 1 EXCEPT SELECT 2 INTERSECT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* top_setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(top_setop, nullptr); + StringRef op_text = top_setop->value(); + EXPECT_TRUE(op_text.equals_ci("EXCEPT", 6)); +} + +// ========== Parenthesized nesting ========== + +TEST_F(MySQLCompoundTest, ParenthesizedUnion) { + const char* sql = "(SELECT 1) UNION (SELECT 2)"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + 
ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(MySQLCompoundTest, ParenthesizedOverridesPrecedence) { + // (SELECT 1 UNION SELECT 2) INTERSECT SELECT 3 + // Parentheses force UNION to be evaluated first + const char* sql = "(SELECT 1 UNION SELECT 2) INTERSECT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); + + auto* top_setop = find_child(r.ast, NodeType::NODE_SET_OPERATION); + ASSERT_NE(top_setop, nullptr); + StringRef op_text = top_setop->value(); + EXPECT_TRUE(op_text.equals_ci("INTERSECT", 9)); +} + +// ========== Complex compound queries ========== + +TEST_F(MySQLCompoundTest, UnionWithFullSelects) { + const char* sql = "SELECT a, b FROM t1 WHERE x = 1 UNION ALL SELECT a, b FROM t2 WHERE y = 2 ORDER BY a LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +// ========== PostgreSQL compound queries ========== + +class PgSQLCompoundTest : public ::testing::Test { +protected: + Parser parser; + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLCompoundTest, SimpleUnion) { + const char* sql = "SELECT 1 UNION SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_COMPOUND_QUERY); +} + +TEST_F(PgSQLCompoundTest, 
IntersectExcept) { + const char* sql = "SELECT 1 INTERSECT SELECT 2 EXCEPT SELECT 3"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLCompoundTest, UnionReturnsCorrectDialect) { + const char* sql = "SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a LIMIT 5"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct CompoundTestCase { + const char* sql; + const char* description; +}; + +static const CompoundTestCase compound_bulk_cases[] = { + {"SELECT 1 UNION SELECT 2", "simple union"}, + {"SELECT 1 UNION ALL SELECT 2", "union all"}, + {"SELECT 1 UNION SELECT 2 UNION SELECT 3", "triple union"}, + {"SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3", "triple union all"}, + {"SELECT 1 INTERSECT SELECT 2", "simple intersect"}, + {"SELECT 1 INTERSECT ALL SELECT 2", "intersect all"}, + {"SELECT 1 EXCEPT SELECT 2", "simple except"}, + {"SELECT 1 EXCEPT ALL SELECT 2", "except all"}, + {"SELECT 1 UNION SELECT 2 INTERSECT SELECT 3", "union + intersect precedence"}, + {"SELECT 1 EXCEPT SELECT 2 INTERSECT SELECT 3", "except + intersect precedence"}, + {"(SELECT 1) UNION (SELECT 2)", "parenthesized"}, + {"(SELECT 1 UNION SELECT 2) INTERSECT SELECT 3", "paren override"}, + {"SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a", "trailing order by"}, + {"SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10", "trailing limit"}, + {"SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a LIMIT 5", "trailing order by + limit"}, + {"SELECT * FROM t1 WHERE x = 1 UNION SELECT * FROM t2 WHERE y = 2", "union with where"}, +}; + +TEST(MySQLCompoundBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : compound_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + 
EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +TEST(PgSQLCompoundBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : compound_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const CompoundTestCase compound_roundtrip_cases[] = { + {"SELECT 1 UNION SELECT 2", "simple union"}, + {"SELECT 1 UNION ALL SELECT 2", "union all"}, + {"SELECT 1 INTERSECT SELECT 2", "intersect"}, + {"SELECT 1 EXCEPT SELECT 2", "except"}, + {"SELECT a FROM t1 UNION SELECT a FROM t2 ORDER BY a", "with order by"}, + {"SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10", "with limit"}, +}; + +TEST(MySQLCompoundRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : compound_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} diff --git a/tests/test_delete.cpp b/tests/test_delete.cpp new file mode 100644 index 0000000..cfab6a5 --- /dev/null +++ b/tests/test_delete.cpp @@ -0,0 +1,335 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLDeleteTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = 
node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Basic DELETE ========== + +TEST_F(MySQLDeleteTest, SimpleDelete) { + auto r = parser.parse("DELETE FROM users WHERE id = 1", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_DELETE_STMT); +} + +TEST_F(MySQLDeleteTest, DeleteNoWhere) { + auto r = parser.parse("DELETE FROM users", 17); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* where = find_child(r.ast, NodeType::NODE_WHERE_CLAUSE); + EXPECT_EQ(where, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteQualifiedTable) { + auto r = parser.parse("DELETE FROM mydb.users WHERE id = 1", 35); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteComplexWhere) { + const char* sql = "DELETE FROM users WHERE status = 'inactive' AND last_login < '2020-01-01'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL Options ========== + +TEST_F(MySQLDeleteTest, DeleteLowPriority) { + auto r = parser.parse("DELETE LOW_PRIORITY FROM users WHERE id = 1", 44); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_STMT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteQuick) { + auto r = parser.parse("DELETE QUICK FROM users WHERE id = 1", 36); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteIgnore) { + auto r = 
parser.parse("DELETE IGNORE FROM users WHERE id = 1", 37); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteAllOptions) { + auto r = parser.parse("DELETE LOW_PRIORITY QUICK IGNORE FROM users WHERE id = 1", 56); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL ORDER BY + LIMIT ========== + +TEST_F(MySQLDeleteTest, DeleteOrderByLimit) { + const char* sql = "DELETE FROM users WHERE active = 0 ORDER BY created_at ASC LIMIT 100"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLDeleteTest, DeleteLimitOnly) { + auto r = parser.parse("DELETE FROM users WHERE active = 0 LIMIT 100", 45); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +// ========== MySQL Multi-Table Form 1: DELETE t1, t2 FROM ... ========== + +TEST_F(MySQLDeleteTest, MultiTableForm1Single) { + const char* sql = "DELETE t1 FROM t1 JOIN t2 ON t1.id = t2.fk WHERE t2.status = 0"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, MultiTableForm1Multiple) { + const char* sql = "DELETE t1, t2 FROM t1 JOIN t2 ON t1.id = t2.fk WHERE t1.status = 0"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL Multi-Table Form 2: DELETE FROM t1, t2 USING ... 
========== + +TEST_F(MySQLDeleteTest, MultiTableForm2) { + const char* sql = "DELETE FROM t1, t2 USING t1 JOIN t2 ON t1.id = t2.fk WHERE t1.status = 0"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLDeleteTest, MultiTableForm2Single) { + const char* sql = "DELETE FROM t1 USING t1 JOIN t2 ON t1.id = t2.fk WHERE t2.bad = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== PostgreSQL DELETE ========== + +class PgSQLDeleteTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLDeleteTest, SimpleDelete) { + auto r = parser.parse("DELETE FROM users WHERE id = 1", 30); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteUsing) { + const char* sql = "DELETE FROM users USING orders WHERE users.id = orders.user_id AND orders.bad = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* using_clause = find_child(r.ast, NodeType::NODE_DELETE_USING_CLAUSE); + ASSERT_NE(using_clause, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteUsingMultiple) { + const char* sql = "DELETE FROM t1 USING t2, t3 WHERE t1.id = t2.fk AND 
t2.id = t3.fk"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteReturning) { + const char* sql = "DELETE FROM users WHERE id = 1 RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteReturningColumns) { + const char* sql = "DELETE FROM users WHERE id = 1 RETURNING id, name"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); + EXPECT_EQ(child_count(ret), 2); +} + +TEST_F(PgSQLDeleteTest, DeleteUsingReturning) { + const char* sql = "DELETE FROM users USING orders " + "WHERE users.id = orders.user_id RETURNING users.id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_DELETE_USING_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE), nullptr); +} + +TEST_F(PgSQLDeleteTest, DeleteWithAlias) { + const char* sql = "DELETE FROM users AS u WHERE u.id = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct DeleteTestCase { + const char* sql; + const char* description; +}; + +static const DeleteTestCase mysql_delete_bulk_cases[] = { + {"DELETE FROM t", "simple no where"}, + {"DELETE FROM t WHERE a = 1", "simple with where"}, + {"DELETE FROM t WHERE a > 1 AND b < 10", "complex where"}, + {"DELETE FROM db.t WHERE a = 1", "qualified table"}, + {"DELETE LOW_PRIORITY FROM t WHERE a = 1", "low priority"}, + {"DELETE QUICK FROM t WHERE a = 1", "quick"}, + {"DELETE IGNORE FROM 
t WHERE a = 1", "ignore"}, + {"DELETE LOW_PRIORITY QUICK IGNORE FROM t WHERE a = 1", "all options"}, + {"DELETE FROM t WHERE a = 1 ORDER BY b LIMIT 10", "order by limit"}, + {"DELETE FROM t WHERE a = 1 LIMIT 100", "limit only"}, + {"DELETE t1 FROM t1 JOIN t2 ON t1.id = t2.fk WHERE t2.x = 0", "multi-table form 1"}, + {"DELETE t1, t2 FROM t1 JOIN t2 ON t1.id = t2.fk", "multi-table form 1 multi target"}, + {"DELETE FROM t1 USING t1 JOIN t2 ON t1.id = t2.fk WHERE t2.x = 0", "multi-table form 2"}, + {"DELETE FROM t1, t2 USING t1 JOIN t2 ON t1.id = t2.fk", "multi-table form 2 multi target"}, +}; + +TEST(MySQLDeleteBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_delete_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +static const DeleteTestCase pgsql_delete_bulk_cases[] = { + {"DELETE FROM t", "simple no where"}, + {"DELETE FROM t WHERE a = 1", "simple with where"}, + {"DELETE FROM t WHERE a > 1 AND b < 10", "complex where"}, + {"DELETE FROM t AS x WHERE x.a = 1", "alias"}, + {"DELETE FROM t USING t2 WHERE t.id = t2.fk", "using single"}, + {"DELETE FROM t USING t2, t3 WHERE t.id = t2.fk AND t2.id = t3.fk", "using multi"}, + {"DELETE FROM t WHERE a = 1 RETURNING *", "returning star"}, + {"DELETE FROM t WHERE a = 1 RETURNING a, b", "returning cols"}, + {"DELETE FROM t USING t2 WHERE t.id = t2.fk RETURNING t.a", "using + returning"}, +}; + +TEST(PgSQLDeleteBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_delete_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::DELETE_STMT) + << 
"Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const DeleteTestCase mysql_delete_roundtrip_cases[] = { + {"DELETE FROM t WHERE a = 1", "simple"}, + {"DELETE LOW_PRIORITY QUICK IGNORE FROM t WHERE a = 1", "all options"}, + {"DELETE FROM t WHERE a = 1 ORDER BY b LIMIT 10", "order by limit"}, +}; + +TEST(MySQLDeleteRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : mysql_delete_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +static const DeleteTestCase pgsql_delete_roundtrip_cases[] = { + {"DELETE FROM t WHERE a = 1", "simple"}, + {"DELETE FROM t USING t2 WHERE t.id = t2.fk", "using"}, + {"DELETE FROM t WHERE a = 1 RETURNING *", "returning"}, +}; + +TEST(PgSQLDeleteRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : pgsql_delete_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} diff --git a/tests/test_digest.cpp b/tests/test_digest.cpp new file mode 100644 index 0000000..7ce878e --- /dev/null +++ b/tests/test_digest.cpp @@ -0,0 +1,322 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/digest.h" + +using namespace sql_parser; + +// Helper struct to hold digest results as stable std::string + hash 
+struct StableDigest { + std::string normalized; + uint64_t hash; +}; + +class MySQLDigestTest : public ::testing::Test { +protected: + Parser parser; + + // AST-based digest (parses SQL, invalidates previous arena allocations) + StableDigest digest_ast(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + Digest digest(parser.arena()); + DigestResult dr; + if (r.ast) { + dr = digest.compute(r.ast); + } else { + dr = digest.compute(sql, strlen(sql)); + } + return StableDigest{std::string(dr.normalized.ptr, dr.normalized.len), dr.hash}; + } + + // Token-level digest (uses arena but does NOT call parse, so arena is stable + // within a single call but may be invalidated by subsequent parse calls) + StableDigest digest_token(const char* sql) { + parser.reset(); + Digest digest(parser.arena()); + auto dr = digest.compute(sql, strlen(sql)); + return StableDigest{std::string(dr.normalized.ptr, dr.normalized.len), dr.hash}; + } + + std::string normalized(const char* sql) { + return digest_ast(sql).normalized; + } + + std::string normalized_token(const char* sql) { + return digest_token(sql).normalized; + } +}; + +// ========== Literal normalization ========== + +TEST_F(MySQLDigestTest, IntegerLiteralNormalized) { + EXPECT_EQ(normalized("SELECT * FROM t WHERE id = 42"), + "SELECT * FROM t WHERE id = ?"); +} + +TEST_F(MySQLDigestTest, FloatLiteralNormalized) { + EXPECT_EQ(normalized("SELECT * FROM t WHERE price > 3.14"), + "SELECT * FROM t WHERE price > ?"); +} + +TEST_F(MySQLDigestTest, StringLiteralNormalized) { + EXPECT_EQ(normalized("SELECT * FROM t WHERE name = 'Alice'"), + "SELECT * FROM t WHERE name = ?"); +} + +TEST_F(MySQLDigestTest, MultipleLiteralsNormalized) { + EXPECT_EQ(normalized("SELECT * FROM t WHERE a = 1 AND b = 'x' AND c = 3.14"), + "SELECT * FROM t WHERE a = ? AND b = ? 
AND c = ?"); +} + +// ========== Same query, different literals => same hash ========== + +TEST_F(MySQLDigestTest, SameQueryDifferentInts) { + auto d1 = digest_ast("SELECT * FROM t WHERE id = 1"); + auto d2 = digest_ast("SELECT * FROM t WHERE id = 999"); + EXPECT_EQ(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, SameQueryDifferentStrings) { + auto d1 = digest_ast("SELECT * FROM t WHERE name = 'Alice'"); + auto d2 = digest_ast("SELECT * FROM t WHERE name = 'Bob'"); + EXPECT_EQ(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, DifferentQueriesDifferentHash) { + auto d1 = digest_ast("SELECT * FROM t WHERE id = 1"); + auto d2 = digest_ast("SELECT * FROM t WHERE name = 1"); + EXPECT_NE(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, DifferentTablesDifferentHash) { + auto d1 = digest_ast("SELECT * FROM users WHERE id = 1"); + auto d2 = digest_ast("SELECT * FROM orders WHERE id = 1"); + EXPECT_NE(d1.hash, d2.hash); +} + +// ========== IN list collapsing ========== + +TEST_F(MySQLDigestTest, InListCollapsed) { + EXPECT_EQ(normalized("SELECT * FROM t WHERE id IN (1, 2, 3)"), + "SELECT * FROM t WHERE id IN (?)"); +} + +TEST_F(MySQLDigestTest, InListDifferentSizesSameHash) { + auto d1 = digest_ast("SELECT * FROM t WHERE id IN (1, 2, 3)"); + auto d2 = digest_ast("SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5)"); + EXPECT_EQ(d1.hash, d2.hash); +} + +TEST_F(MySQLDigestTest, InListSingleValueSameHash) { + auto d1 = digest_ast("SELECT * FROM t WHERE id IN (1)"); + auto d2 = digest_ast("SELECT * FROM t WHERE id IN (1, 2, 3)"); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== Keyword uppercasing ========== + +TEST_F(MySQLDigestTest, KeywordsUppercased) { + EXPECT_EQ(normalized_token("select * from t where id = 1"), + "SELECT * FROM t WHERE id = ?"); +} + +// ========== Token-level fallback for Tier 2 ========== + +TEST_F(MySQLDigestTest, TokenLevelInsert) { + EXPECT_EQ(normalized_token("INSERT INTO users (name) VALUES ('Alice')"), + "INSERT INTO users (name) VALUES (?)"); +} + 
+TEST_F(MySQLDigestTest, TokenLevelUpdate) { + EXPECT_EQ(normalized_token("UPDATE users SET name = 'Bob' WHERE id = 42"), + "UPDATE users SET name = ? WHERE id = ?"); +} + +TEST_F(MySQLDigestTest, TokenLevelDelete) { + EXPECT_EQ(normalized_token("DELETE FROM users WHERE id = 1"), + "DELETE FROM users WHERE id = ?"); +} + +TEST_F(MySQLDigestTest, TokenLevelCreateTable) { + EXPECT_EQ(normalized_token("CREATE TABLE t (id INT DEFAULT 0)"), + "CREATE TABLE t (id INT DEFAULT ?)"); +} + +TEST_F(MySQLDigestTest, TokenLevelInCollapsing) { + EXPECT_EQ(normalized_token("SELECT * FROM t WHERE id IN (1, 2, 3, 4, 5)"), + "SELECT * FROM t WHERE id IN (?)"); +} + +// ========== SET statement digest ========== + +TEST_F(MySQLDigestTest, SetVariableDigest) { + auto d1 = digest_ast("SET autocommit = 1"); + auto d2 = digest_ast("SET autocommit = 0"); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== NULL and boolean literals ========== + +TEST_F(MySQLDigestTest, NullPreserved) { + EXPECT_EQ(normalized("SELECT * FROM t WHERE a IS NULL"), + "SELECT * FROM t WHERE a IS NULL"); +} + +TEST_F(MySQLDigestTest, LimitDigest) { + auto d1 = digest_ast("SELECT * FROM t LIMIT 10"); + auto d2 = digest_ast("SELECT * FROM t LIMIT 20"); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== Placeholder passthrough ========== + +TEST_F(MySQLDigestTest, PlaceholderPassthrough) { + EXPECT_EQ(normalized_token("SELECT * FROM t WHERE id = ?"), + "SELECT * FROM t WHERE id = ?"); +} + +// ========== Hash stability ========== + +TEST_F(MySQLDigestTest, HashStability) { + auto d1 = digest_ast("SELECT * FROM users WHERE id = 1"); + EXPECT_NE(d1.hash, 0ULL); + auto d2 = digest_ast("SELECT * FROM users WHERE id = 42"); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== Consistency: AST-based and token-level produce same hash ========== + +TEST_F(MySQLDigestTest, AstAndTokenLevelConsistentForSimpleSelect) { + auto d_ast = digest_ast("SELECT * FROM users WHERE id = 42"); + auto d_tok = digest_token("SELECT * FROM users 
WHERE id = 42"); + EXPECT_EQ(d_ast.normalized, d_tok.normalized); + EXPECT_EQ(d_ast.hash, d_tok.hash); +} + +// ========== Bulk digest tests ========== + +struct DigestTestCase { + const char* sql1; + const char* sql2; + bool same_hash; + const char* description; +}; + +static const DigestTestCase digest_bulk_cases[] = { + {"SELECT * FROM t WHERE id = 1", "SELECT * FROM t WHERE id = 2", true, "different int literals"}, + {"SELECT * FROM t WHERE s = 'a'", "SELECT * FROM t WHERE s = 'b'", true, "different string literals"}, + {"SELECT * FROM t WHERE x = 1.5", "SELECT * FROM t WHERE x = 2.7", true, "different float literals"}, + {"SELECT * FROM t WHERE id IN (1,2)", "SELECT * FROM t WHERE id IN (1,2,3,4)", true, "in list sizes"}, + {"SELECT * FROM t LIMIT 10", "SELECT * FROM t LIMIT 100", true, "different limits"}, + {"SELECT * FROM t1 WHERE id = 1", "SELECT * FROM t2 WHERE id = 1", false, "different tables"}, + {"SELECT a FROM t WHERE id = 1", "SELECT b FROM t WHERE id = 1", false, "different columns"}, + {"SELECT * FROM t WHERE a = 1", "SELECT * FROM t WHERE b = 1", false, "different where cols"}, + {"SELECT * FROM t ORDER BY a", "SELECT * FROM t ORDER BY b", false, "different order"}, +}; + +TEST(MySQLDigestBulk, HashConsistency) { + Parser parser; + for (const auto& tc : digest_bulk_cases) { + // Parse and digest first query, copy results + auto r1 = parser.parse(tc.sql1, strlen(tc.sql1)); + Digest d1(parser.arena()); + auto dr1 = r1.ast ? d1.compute(r1.ast) : d1.compute(tc.sql1, strlen(tc.sql1)); + std::string norm1(dr1.normalized.ptr, dr1.normalized.len); + uint64_t hash1 = dr1.hash; + + // Parse and digest second query + auto r2 = parser.parse(tc.sql2, strlen(tc.sql2)); + Digest d2(parser.arena()); + auto dr2 = r2.ast ? 
d2.compute(r2.ast) : d2.compute(tc.sql2, strlen(tc.sql2)); + std::string norm2(dr2.normalized.ptr, dr2.normalized.len); + uint64_t hash2 = dr2.hash; + + if (tc.same_hash) { + EXPECT_EQ(hash1, hash2) + << "Expected same hash: " << tc.description + << "\n SQL1: " << tc.sql1 << "\n SQL2: " << tc.sql2 + << "\n Norm1: " << norm1 + << "\n Norm2: " << norm2; + } else { + EXPECT_NE(hash1, hash2) + << "Expected different hash: " << tc.description + << "\n SQL1: " << tc.sql1 << "\n SQL2: " << tc.sql2; + } + } +} + +// ========== INSERT digest (AST-based) ========== + +TEST_F(MySQLDigestTest, InsertDigestNormalized) { + EXPECT_EQ(normalized("INSERT INTO t (a, b) VALUES (1, 'hello')"), + "INSERT INTO t (a, b) VALUES (?, ?)"); +} + +TEST_F(MySQLDigestTest, InsertMultiRowCollapsed) { + auto d1 = digest_ast("INSERT INTO t (a) VALUES (1)"); + auto d2 = digest_ast("INSERT INTO t (a) VALUES (1), (2), (3)"); + EXPECT_EQ(d1.normalized, d2.normalized); + EXPECT_EQ(d1.hash, d2.hash); +} + +// ========== PostgreSQL digest ========== + +class PgSQLDigestTest : public ::testing::Test { +protected: + Parser parser; + + StableDigest digest_token(const char* sql) { + parser.reset(); + Digest digest(parser.arena()); + auto dr = digest.compute(sql, strlen(sql)); + return StableDigest{std::string(dr.normalized.ptr, dr.normalized.len), dr.hash}; + } + + std::string normalized_token(const char* sql) { + return digest_token(sql).normalized; + } +}; + +TEST_F(PgSQLDigestTest, BasicDigest) { + EXPECT_EQ(normalized_token("SELECT * FROM users WHERE id = 42"), + "SELECT * FROM users WHERE id = ?"); +} + +TEST_F(PgSQLDigestTest, DollarPlaceholderPreserved) { + EXPECT_EQ(normalized_token("SELECT * FROM users WHERE id = $1"), + "SELECT * FROM users WHERE id = $1"); +} + +TEST_F(PgSQLDigestTest, InListCollapsed) { + EXPECT_EQ(normalized_token("SELECT * FROM t WHERE id IN (1, 2, 3)"), + "SELECT * FROM t WHERE id IN (?)"); +} + +TEST_F(PgSQLDigestTest, ReturningDigest) { + EXPECT_EQ(normalized_token("INSERT 
INTO t (a) VALUES (1) RETURNING *"), + "INSERT INTO t (a) VALUES (?) RETURNING *"); +} + +// ========== Token-level digest for various Tier 2 statements ========== + +TEST_F(MySQLDigestTest, TokenLevelGrant) { + std::string out = normalized_token("GRANT SELECT ON db.* TO 'user'@'host'"); + EXPECT_EQ(out, "GRANT SELECT ON db.* TO ?@?"); +} + +TEST_F(MySQLDigestTest, TokenLevelDropTable) { + EXPECT_EQ(normalized_token("DROP TABLE IF EXISTS t"), + "DROP TABLE IF EXISTS t"); +} + +// ========== Token-level VALUES collapsing ========== + +TEST_F(MySQLDigestTest, TokenLevelValuesMultiRowCollapsed) { + EXPECT_EQ(normalized_token("INSERT INTO t (a, b) VALUES (1, 'x'), (2, 'y'), (3, 'z')"), + "INSERT INTO t (a, b) VALUES (?, ?)"); +} + +TEST_F(MySQLDigestTest, TokenLevelValuesMultiRowSameHash) { + auto d1 = digest_token("INSERT INTO t (a) VALUES (1)"); + auto d2 = digest_token("INSERT INTO t (a) VALUES (1), (2), (3)"); + EXPECT_EQ(d1.hash, d2.hash); +} diff --git a/tests/test_emitter.cpp b/tests/test_emitter.cpp new file mode 100644 index 0000000..5428b92 --- /dev/null +++ b/tests/test_emitter.cpp @@ -0,0 +1,303 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLEmitterTest : public ::testing::Test { +protected: + Parser parser; + + // Parse, emit, return the emitted string + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== SET round-trips ========== + +TEST_F(MySQLEmitterTest, SetSimpleVariable) { + std::string out = round_trip("SET autocommit = 1"); + EXPECT_EQ(out, "SET autocommit = 1"); +} + +TEST_F(MySQLEmitterTest, SetMultipleVariables) { + std::string out = round_trip("SET autocommit = 1, wait_timeout = 28800"); + EXPECT_EQ(out, "SET autocommit = 1, 
wait_timeout = 28800"); +} + +TEST_F(MySQLEmitterTest, SetNames) { + std::string out = round_trip("SET NAMES utf8mb4"); + EXPECT_EQ(out, "SET NAMES utf8mb4"); +} + +TEST_F(MySQLEmitterTest, SetNamesCollate) { + std::string out = round_trip("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"); + EXPECT_EQ(out, "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"); +} + +TEST_F(MySQLEmitterTest, SetCharacterSet) { + std::string out = round_trip("SET CHARACTER SET utf8"); + EXPECT_EQ(out, "SET CHARACTER SET utf8"); +} + +TEST_F(MySQLEmitterTest, SetCharset) { + // CHARSET is normalized to CHARACTER SET in emitted output + std::string out = round_trip("SET CHARSET utf8"); + EXPECT_EQ(out, "SET CHARACTER SET utf8"); +} + +TEST_F(MySQLEmitterTest, SetGlobalVariable) { + std::string out = round_trip("SET GLOBAL max_connections = 100"); + EXPECT_EQ(out, "SET GLOBAL max_connections = 100"); +} + +TEST_F(MySQLEmitterTest, SetSessionVariable) { + std::string out = round_trip("SET SESSION wait_timeout = 600"); + EXPECT_EQ(out, "SET SESSION wait_timeout = 600"); +} + +TEST_F(MySQLEmitterTest, SetDoubleAtVariable) { + std::string out = round_trip("SET @@session.wait_timeout = 600"); + EXPECT_EQ(out, "SET @@session.wait_timeout = 600"); +} + +TEST_F(MySQLEmitterTest, SetUserVariable) { + std::string out = round_trip("SET @my_var = 42"); + EXPECT_EQ(out, "SET @my_var = 42"); +} + +TEST_F(MySQLEmitterTest, SetTransaction) { + std::string out = round_trip("SET TRANSACTION READ ONLY"); + EXPECT_EQ(out, "SET TRANSACTION READ ONLY"); +} + +TEST_F(MySQLEmitterTest, SetTransactionIsolation) { + // ISOLATION LEVEL keywords are consumed by parser; emitter outputs the level value directly + std::string out = round_trip("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); + EXPECT_EQ(out, "SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"); + // Note: To support this, the SET parser must preserve "ISOLATION LEVEL" in the AST. 
+ // The emitter's emit_set_transaction() must check children and re-insert the keywords. +} + +TEST_F(MySQLEmitterTest, SetFunctionRHS) { + std::string out = round_trip("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); + EXPECT_EQ(out, "SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); +} + +// ========== SELECT round-trips ========== + +TEST_F(MySQLEmitterTest, SelectLiteral) { + std::string out = round_trip("SELECT 1"); + EXPECT_EQ(out, "SELECT 1"); +} + +TEST_F(MySQLEmitterTest, SelectStar) { + std::string out = round_trip("SELECT * FROM users"); + EXPECT_EQ(out, "SELECT * FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectColumns) { + std::string out = round_trip("SELECT id, name FROM users"); + EXPECT_EQ(out, "SELECT id, name FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectWithAlias) { + std::string out = round_trip("SELECT id AS user_id FROM users"); + EXPECT_EQ(out, "SELECT id AS user_id FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectDistinct) { + std::string out = round_trip("SELECT DISTINCT name FROM users"); + EXPECT_EQ(out, "SELECT DISTINCT name FROM users"); +} + +TEST_F(MySQLEmitterTest, SelectWhere) { + std::string out = round_trip("SELECT * FROM users WHERE id = 1"); + EXPECT_EQ(out, "SELECT * FROM users WHERE id = 1"); +} + +TEST_F(MySQLEmitterTest, SelectWhereAnd) { + std::string out = round_trip("SELECT * FROM users WHERE age > 18 AND status = 'active'"); + EXPECT_EQ(out, "SELECT * FROM users WHERE age > 18 AND status = 'active'"); +} + +TEST_F(MySQLEmitterTest, SelectJoin) { + std::string out = round_trip("SELECT * FROM users JOIN orders ON users.id = orders.user_id"); + EXPECT_EQ(out, "SELECT * FROM users JOIN orders ON users.id = orders.user_id"); +} + +TEST_F(MySQLEmitterTest, SelectLeftJoin) { + std::string out = round_trip("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id"); + EXPECT_EQ(out, "SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id"); +} + +TEST_F(MySQLEmitterTest, 
SelectGroupBy) { + std::string out = round_trip("SELECT status, COUNT(*) FROM users GROUP BY status"); + EXPECT_EQ(out, "SELECT status, COUNT(*) FROM users GROUP BY status"); +} + +TEST_F(MySQLEmitterTest, SelectGroupByHaving) { + std::string out = round_trip("SELECT status, COUNT(*) FROM users GROUP BY status HAVING COUNT(*) > 5"); + EXPECT_EQ(out, "SELECT status, COUNT(*) FROM users GROUP BY status HAVING COUNT(*) > 5"); +} + +TEST_F(MySQLEmitterTest, SelectOrderBy) { + std::string out = round_trip("SELECT * FROM users ORDER BY name ASC"); + EXPECT_EQ(out, "SELECT * FROM users ORDER BY name ASC"); +} + +TEST_F(MySQLEmitterTest, SelectLimit) { + std::string out = round_trip("SELECT * FROM users LIMIT 10"); + EXPECT_EQ(out, "SELECT * FROM users LIMIT 10"); +} + +TEST_F(MySQLEmitterTest, SelectLimitOffset) { + std::string out = round_trip("SELECT * FROM users LIMIT 10 OFFSET 20"); + EXPECT_EQ(out, "SELECT * FROM users LIMIT 10 OFFSET 20"); +} + +TEST_F(MySQLEmitterTest, SelectForUpdate) { + std::string out = round_trip("SELECT * FROM users FOR UPDATE"); + EXPECT_EQ(out, "SELECT * FROM users FOR UPDATE"); +} + +// ========== Expression round-trips ========== + +TEST_F(MySQLEmitterTest, ExprIsNull) { + std::string out = round_trip("SELECT * FROM t WHERE x IS NULL"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x IS NULL"); +} + +TEST_F(MySQLEmitterTest, ExprIsNotNull) { + std::string out = round_trip("SELECT * FROM t WHERE x IS NOT NULL"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x IS NOT NULL"); +} + +TEST_F(MySQLEmitterTest, ExprBetween) { + std::string out = round_trip("SELECT * FROM t WHERE x BETWEEN 1 AND 10"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x BETWEEN 1 AND 10"); +} + +TEST_F(MySQLEmitterTest, ExprIn) { + std::string out = round_trip("SELECT * FROM t WHERE x IN (1, 2, 3)"); + EXPECT_EQ(out, "SELECT * FROM t WHERE x IN (1, 2, 3)"); +} + +TEST_F(MySQLEmitterTest, ExprFunctionCall) { + std::string out = round_trip("SELECT COUNT(*) FROM users"); + EXPECT_EQ(out, 
"SELECT COUNT(*) FROM users"); +} + +TEST_F(MySQLEmitterTest, ExprUnaryMinus) { + std::string out = round_trip("SELECT -1"); + EXPECT_EQ(out, "SELECT -1"); +} + +// ========== Bulk round-trip tests ========== + +struct RoundTripCase { + const char* sql; + const char* description; +}; + +static const RoundTripCase roundtrip_cases[] = { + {"SET autocommit = 0", "set simple"}, + {"SET NAMES utf8", "set names"}, + {"SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci", "set names collate"}, + {"SET CHARACTER SET utf8", "set character set"}, + {"SET GLOBAL max_connections = 100", "set global"}, + {"SET @x = 42", "set user var"}, + {"SET @@session.wait_timeout = 600", "set sys var"}, + {"SELECT 1", "select literal"}, + {"SELECT * FROM t", "select star"}, + {"SELECT a, b FROM t", "select columns"}, + {"SELECT a AS x FROM t", "select alias"}, + {"SELECT DISTINCT a FROM t", "select distinct"}, + {"SELECT * FROM t WHERE a = 1", "select where"}, + {"SELECT * FROM t WHERE a > 1 AND b < 10", "select where and"}, + {"SELECT * FROM t ORDER BY a", "select order by"}, + {"SELECT * FROM t ORDER BY a DESC", "select order by desc"}, + {"SELECT * FROM t LIMIT 10", "select limit"}, + {"SELECT * FROM t LIMIT 10 OFFSET 5", "select limit offset"}, + {"SELECT * FROM t FOR UPDATE", "select for update"}, + {"SELECT COUNT(*) FROM t", "select count"}, + {"SELECT * FROM t WHERE x IS NULL", "is null"}, + {"SELECT * FROM t WHERE x IS NOT NULL", "is not null"}, + {"SELECT * FROM t WHERE x IN (1, 2, 3)", "in list"}, + {"SELECT * FROM t WHERE x BETWEEN 1 AND 10", "between"}, +}; + +TEST(MySQLEmitterBulk, RoundTripsMatch) { + Parser parser; + for (const auto& tc : roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << 
"Round-trip mismatch: " << tc.description; + } +} + +// ========== AST modification tests ========== + +TEST_F(MySQLEmitterTest, ModifySetValue) { + // Parse SET autocommit = 1, modify value to 0, emit + auto r = parser.parse("SET autocommit = 1", 18); + ASSERT_NE(r.ast, nullptr); + + // Navigate to the value node: SET_STMT -> VAR_ASSIGNMENT -> (target, value) + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + ASSERT_EQ(assignment->type, NodeType::NODE_VAR_ASSIGNMENT); + + // Second child of assignment is the RHS value + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* value = target->next_sibling; + ASSERT_NE(value, nullptr); + + // Modify the value + const char* new_val = "0"; + value->value_ptr = new_val; + value->value_len = 1; + + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET autocommit = 0"); +} + +// ========== PostgreSQL round-trips ========== + +TEST(PgSQLEmitterTest, SetVarTo) { + // PostgreSQL TO is normalized to = in emitted output + Parser parser; + auto r = parser.parse("SET client_encoding TO 'UTF8'", 29); + ASSERT_NE(r.ast, nullptr); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET client_encoding = 'UTF8'"); +} + +TEST(PgSQLEmitterTest, SelectBasic) { + Parser parser; + auto r = parser.parse("SELECT * FROM users WHERE id = 1", 32); + ASSERT_NE(r.ast, nullptr); + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SELECT * FROM users WHERE id = 1"); +} diff --git a/tests/test_expression.cpp b/tests/test_expression.cpp new file mode 100644 index 0000000..03adc27 --- /dev/null +++ b/tests/test_expression.cpp @@ -0,0 +1,290 @@ +#include +#include 
"sql_parser/parser.h" +#include "sql_parser/expression_parser.h" + +using namespace sql_parser; + +// Helper: parse an expression from a SQL string using a fresh parser context. +// We use the tokenizer directly since expression parsing is an internal function. +class ExpressionTest : public ::testing::Test { +protected: + Arena arena{4096}; + Tokenizer tok; + + AstNode* parse_expr(const char* sql) { + tok.reset(sql, strlen(sql)); + ExpressionParser ep(tok, arena); + return ep.parse(); + } +}; + +// ===== Task 1: Literals and Identifiers ===== + +TEST_F(ExpressionTest, IntegerLiteral) { + AstNode* node = parse_expr("42"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "42"); +} + +TEST_F(ExpressionTest, FloatLiteral) { + AstNode* node = parse_expr("3.14"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_FLOAT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "3.14"); +} + +TEST_F(ExpressionTest, StringLiteral) { + AstNode* node = parse_expr("'hello'"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_STRING); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "hello"); +} + +TEST_F(ExpressionTest, NullLiteral) { + AstNode* node = parse_expr("NULL"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_NULL); +} + +TEST_F(ExpressionTest, TrueLiteral) { + AstNode* node = parse_expr("TRUE"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "TRUE"); +} + +TEST_F(ExpressionTest, FalseLiteral) { + AstNode* node = parse_expr("FALSE"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "FALSE"); +} + +TEST_F(ExpressionTest, SimpleIdentifier) { + AstNode* node = parse_expr("my_column"); + ASSERT_NE(node, nullptr); 
+ EXPECT_EQ(node->type, NodeType::NODE_COLUMN_REF); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "my_column"); +} + +TEST_F(ExpressionTest, QualifiedIdentifier) { + AstNode* node = parse_expr("t.col"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_QUALIFIED_NAME); + // first child = table, second child = column + ASSERT_NE(node->first_child, nullptr); + ASSERT_NE(node->first_child->next_sibling, nullptr); +} + +TEST_F(ExpressionTest, Asterisk) { + AstNode* node = parse_expr("*"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_ASTERISK); +} + +TEST_F(ExpressionTest, Placeholder) { + AstNode* node = parse_expr("?"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_PLACEHOLDER); +} + +TEST_F(ExpressionTest, DefaultKeyword) { + AstNode* node = parse_expr("DEFAULT"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IDENTIFIER); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "DEFAULT"); +} + +TEST_F(ExpressionTest, UserVariable) { + AstNode* node = parse_expr("@my_var"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_COLUMN_REF); +} + +TEST_F(ExpressionTest, ParenthesizedExpression) { + AstNode* node = parse_expr("(42)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_LITERAL_INT); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "42"); +} + +// ===== Task 2: Binary Operators, IS NULL, BETWEEN, IN, Functions ===== + +TEST_F(ExpressionTest, BinaryAdd) { + AstNode* node = parse_expr("1 + 2"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "+"); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_LITERAL_INT); + ASSERT_NE(node->first_child->next_sibling, nullptr); + EXPECT_EQ(node->first_child->next_sibling->type, NodeType::NODE_LITERAL_INT); +} + +TEST_F(ExpressionTest, Precedence_MulOverAdd) 
{ + // 1 + 2 * 3 should parse as 1 + (2 * 3) + AstNode* node = parse_expr("1 + 2 * 3"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "+"); + // Right child should be 2*3 + AstNode* right = node->first_child->next_sibling; + ASSERT_NE(right, nullptr); + EXPECT_EQ(right->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(right->value_ptr, right->value_len), "*"); +} + +TEST_F(ExpressionTest, ComparisonEqual) { + AstNode* node = parse_expr("x = 1"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "="); +} + +TEST_F(ExpressionTest, LogicalAnd) { + AstNode* node = parse_expr("a = 1 AND b = 2"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "AND"); +} + +TEST_F(ExpressionTest, LogicalOr) { + AstNode* node = parse_expr("a = 1 OR b = 2"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "OR"); +} + +TEST_F(ExpressionTest, UnaryMinus) { + AstNode* node = parse_expr("-42"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_LITERAL_INT); +} + +TEST_F(ExpressionTest, UnaryNot) { + AstNode* node = parse_expr("NOT x"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); +} + +TEST_F(ExpressionTest, IsNull) { + AstNode* node = parse_expr("x IS NULL"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IS_NULL); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_COLUMN_REF); +} + +TEST_F(ExpressionTest, IsNotNull) { + AstNode* node = parse_expr("x IS NOT NULL"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, 
NodeType::NODE_IS_NOT_NULL); +} + +TEST_F(ExpressionTest, Between) { + AstNode* node = parse_expr("x BETWEEN 1 AND 10"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BETWEEN); + // 3 children: expr, low, high + ASSERT_NE(node->first_child, nullptr); + ASSERT_NE(node->first_child->next_sibling, nullptr); + ASSERT_NE(node->first_child->next_sibling->next_sibling, nullptr); +} + +TEST_F(ExpressionTest, InList) { + AstNode* node = parse_expr("x IN (1, 2, 3)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_IN_LIST); + // Children: expr, val1, val2, val3 = 4 children + int count = 0; + for (AstNode* c = node->first_child; c; c = c->next_sibling) ++count; + EXPECT_EQ(count, 4); +} + +TEST_F(ExpressionTest, FunctionCall) { + AstNode* node = parse_expr("COUNT(*)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_FUNCTION_CALL); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "COUNT"); + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_ASTERISK); +} + +TEST_F(ExpressionTest, FunctionCallMultiArg) { + AstNode* node = parse_expr("COALESCE(a, b, 0)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_FUNCTION_CALL); + int count = 0; + for (AstNode* c = node->first_child; c; c = c->next_sibling) ++count; + EXPECT_EQ(count, 3); +} + +TEST_F(ExpressionTest, NestedParens) { + AstNode* node = parse_expr("(1 + 2) * 3"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "*"); + // Left child should be 1+2 + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->first_child->value_ptr, node->first_child->value_len), "+"); +} + +TEST_F(ExpressionTest, LikeOperator) { + AstNode* node = parse_expr("name LIKE '%test%'"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, 
NodeType::NODE_BINARY_OP); + EXPECT_EQ(std::string(node->value_ptr, node->value_len), "LIKE"); +} + +TEST_F(ExpressionTest, StringConcat) { + AstNode* node = parse_expr("a || b"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_BINARY_OP); +} + +TEST_F(ExpressionTest, NotIn) { + AstNode* node = parse_expr("x NOT IN (1, 2)"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); // NOT wraps IN_LIST + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_IN_LIST); +} + +TEST_F(ExpressionTest, NotBetween) { + AstNode* node = parse_expr("x NOT BETWEEN 1 AND 10"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); // NOT wraps BETWEEN + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_BETWEEN); +} + +TEST_F(ExpressionTest, NotLike) { + AstNode* node = parse_expr("name NOT LIKE '%test'"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_UNARY_OP); // NOT wraps LIKE + ASSERT_NE(node->first_child, nullptr); + EXPECT_EQ(node->first_child->type, NodeType::NODE_BINARY_OP); +} + +TEST_F(ExpressionTest, CaseWhenSimple) { + AstNode* node = parse_expr("CASE WHEN x = 1 THEN 'a' ELSE 'b' END"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_CASE_WHEN); +} + +TEST_F(ExpressionTest, CaseWhenSearched) { + AstNode* node = parse_expr("CASE x WHEN 1 THEN 'a' WHEN 2 THEN 'b' END"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_CASE_WHEN); +} + +TEST_F(ExpressionTest, ZeroArgFunction) { + AstNode* node = parse_expr("NOW()"); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->type, NodeType::NODE_FUNCTION_CALL); + EXPECT_EQ(node->first_child, nullptr); // no args +} diff --git a/tests/test_insert.cpp b/tests/test_insert.cpp new file mode 100644 index 0000000..a68b0f3 --- /dev/null +++ b/tests/test_insert.cpp @@ -0,0 +1,419 @@ +#include +#include "sql_parser/parser.h" +#include 
"sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLInsertTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Basic INSERT ========== + +TEST_F(MySQLInsertTest, SimpleInsert) { + auto r = parser.parse("INSERT INTO users (id, name) VALUES (1, 'Alice')", 49); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_INSERT_STMT); +} + +TEST_F(MySQLInsertTest, InsertWithoutInto) { + auto r = parser.parse("INSERT users (id) VALUES (1)", 28); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertWithoutColumnList) { + auto r = parser.parse("INSERT INTO users VALUES (1, 'Alice')", 37); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* cols = find_child(r.ast, NodeType::NODE_INSERT_COLUMNS); + EXPECT_EQ(cols, nullptr); // no column list +} + +TEST_F(MySQLInsertTest, InsertColumnList) { + auto r = parser.parse("INSERT INTO users (id, name, email) VALUES (1, 'Alice', 'a@b.com')", 67); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* cols = find_child(r.ast, NodeType::NODE_INSERT_COLUMNS); + ASSERT_NE(cols, nullptr); + EXPECT_EQ(child_count(cols), 3); +} + +TEST_F(MySQLInsertTest, 
InsertMultiRow) { + auto r = parser.parse("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')", 60); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* values = find_child(r.ast, NodeType::NODE_VALUES_CLAUSE); + ASSERT_NE(values, nullptr); + EXPECT_EQ(child_count(values), 2); // two rows +} + +TEST_F(MySQLInsertTest, InsertTableRef) { + auto r = parser.parse("INSERT INTO mydb.users (id) VALUES (1)", 39); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* tref = find_child(r.ast, NodeType::NODE_TABLE_REF); + ASSERT_NE(tref, nullptr); +} + +// ========== MySQL Options ========== + +TEST_F(MySQLInsertTest, InsertLowPriority) { + auto r = parser.parse("INSERT LOW_PRIORITY INTO users (id) VALUES (1)", 47); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_STMT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLInsertTest, InsertDelayed) { + auto r = parser.parse("INSERT DELAYED INTO users (id) VALUES (1)", 42); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertHighPriority) { + auto r = parser.parse("INSERT HIGH_PRIORITY INTO users (id) VALUES (1)", 48); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertIgnore) { + auto r = parser.parse("INSERT IGNORE INTO users (id) VALUES (1)", 41); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLInsertTest, InsertLowPriorityIgnore) { + auto r = parser.parse("INSERT LOW_PRIORITY IGNORE INTO users (id) VALUES (1)", 54); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== INSERT ... 
SELECT ========== + +TEST_F(MySQLInsertTest, InsertSelect) { + auto r = parser.parse("INSERT INTO users (id, name) SELECT id, name FROM temp_users", 60); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* select = find_child(r.ast, NodeType::NODE_SELECT_STMT); + ASSERT_NE(select, nullptr); +} + +TEST_F(MySQLInsertTest, InsertSelectWithWhere) { + const char* sql = "INSERT INTO users (id, name) SELECT id, name FROM temp WHERE active = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL INSERT ... SET ========== + +TEST_F(MySQLInsertTest, InsertSet) { + auto r = parser.parse("INSERT INTO users SET id = 1, name = 'Alice'", 45); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* set_clause = find_child(r.ast, NodeType::NODE_INSERT_SET_CLAUSE); + ASSERT_NE(set_clause, nullptr); + EXPECT_EQ(child_count(set_clause), 2); // two col=val pairs +} + +// ========== ON DUPLICATE KEY UPDATE ========== + +TEST_F(MySQLInsertTest, OnDuplicateKey) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') ON DUPLICATE KEY UPDATE name = 'Alice2'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* odku = find_child(r.ast, NodeType::NODE_ON_DUPLICATE_KEY); + ASSERT_NE(odku, nullptr); +} + +TEST_F(MySQLInsertTest, OnDuplicateKeyMultiple) { + const char* sql = "INSERT INTO users (id, name, email) VALUES (1, 'Alice', 'a@b.com') " + "ON DUPLICATE KEY UPDATE name = VALUES(name), email = VALUES(email)"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* odku = find_child(r.ast, NodeType::NODE_ON_DUPLICATE_KEY); + ASSERT_NE(odku, nullptr); + EXPECT_EQ(child_count(odku), 2); +} + +// ========== REPLACE ========== + +TEST_F(MySQLInsertTest, ReplaceSimple) { + auto r = parser.parse("REPLACE INTO users (id, 
name) VALUES (1, 'Alice')", 50); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REPLACE); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_INSERT_STMT); + // REPLACE flag should be set in flags + EXPECT_NE(r.ast->flags & 0x01, 0); // FLAG_REPLACE = 0x01 +} + +TEST_F(MySQLInsertTest, ReplaceLowPriority) { + auto r = parser.parse("REPLACE LOW_PRIORITY INTO users (id) VALUES (1)", 48); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::REPLACE); +} + +TEST_F(MySQLInsertTest, ReplaceDelayed) { + auto r = parser.parse("REPLACE DELAYED INTO users (id) VALUES (1)", 43); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== PostgreSQL INSERT ========== + +class PgSQLInsertTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLInsertTest, SimpleInsert) { + auto r = parser.parse("INSERT INTO users (id, name) VALUES (1, 'Alice')", 49); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::INSERT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, DefaultValues) { + auto r = parser.parse("INSERT INTO users DEFAULT VALUES", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictDoNothing) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT DO 
NOTHING"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* oc = find_child(r.ast, NodeType::NODE_ON_CONFLICT); + ASSERT_NE(oc, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictDoUpdate) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT (id) DO UPDATE SET name = 'Alice2'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* oc = find_child(r.ast, NodeType::NODE_ON_CONFLICT); + ASSERT_NE(oc, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictOnConstraint) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT ON CONSTRAINT users_pkey DO NOTHING"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictDoUpdateWhere) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name WHERE users.active = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, Returning) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') RETURNING id, name"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); + EXPECT_EQ(child_count(ret), 2); +} + +TEST_F(PgSQLInsertTest, ReturningStar) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLInsertTest, OnConflictWithReturning) { + const char* sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') " + "ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name 
RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ON_CONFLICT), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE), nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct InsertTestCase { + const char* sql; + const char* description; +}; + +static const InsertTestCase mysql_insert_bulk_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple single column"}, + {"INSERT INTO t (a, b) VALUES (1, 2)", "two columns"}, + {"INSERT INTO t (a, b, c) VALUES (1, 2, 3)", "three columns"}, + {"INSERT INTO t VALUES (1, 2)", "no column list"}, + {"INSERT t (a) VALUES (1)", "without INTO"}, + {"INSERT INTO db.t (a) VALUES (1)", "qualified table"}, + {"INSERT INTO t (a) VALUES (1), (2), (3)", "multi-row"}, + {"INSERT INTO t (a, b) VALUES (1, 'x'), (2, 'y')", "multi-row with strings"}, + {"INSERT LOW_PRIORITY INTO t (a) VALUES (1)", "low priority"}, + {"INSERT DELAYED INTO t (a) VALUES (1)", "delayed"}, + {"INSERT HIGH_PRIORITY INTO t (a) VALUES (1)", "high priority"}, + {"INSERT IGNORE INTO t (a) VALUES (1)", "ignore"}, + {"INSERT LOW_PRIORITY IGNORE INTO t (a) VALUES (1)", "low priority ignore"}, + {"INSERT INTO t SET a = 1", "set form single"}, + {"INSERT INTO t SET a = 1, b = 'x'", "set form multiple"}, + {"INSERT INTO t (a) SELECT a FROM t2", "insert select"}, + {"INSERT INTO t (a, b) SELECT a, b FROM t2 WHERE c > 0", "insert select with where"}, + {"INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 2", "on duplicate key"}, + {"INSERT INTO t (a, b) VALUES (1, 'x') ON DUPLICATE KEY UPDATE b = VALUES(b)", "odku values()"}, + {"INSERT INTO t (a, b) VALUES (1, 'x') ON DUPLICATE KEY UPDATE a = a + 1, b = 'y'", "odku multi"}, + {"REPLACE INTO t (a) VALUES (1)", "replace simple"}, + {"REPLACE INTO t (a, b) VALUES (1, 2)", "replace two cols"}, + {"REPLACE LOW_PRIORITY INTO t (a) VALUES (1)", "replace low priority"}, + 
{"REPLACE INTO t SET a = 1", "replace set form"}, +}; + +TEST(MySQLInsertBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_insert_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +static const InsertTestCase pgsql_insert_bulk_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple"}, + {"INSERT INTO t (a, b) VALUES (1, 2)", "two columns"}, + {"INSERT INTO t VALUES (1, 2)", "no column list"}, + {"INSERT INTO t DEFAULT VALUES", "default values"}, + {"INSERT INTO t (a) VALUES (1), (2)", "multi-row"}, + {"INSERT INTO t (a) SELECT a FROM t2", "insert select"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING", "on conflict do nothing"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO NOTHING", "on conflict col do nothing"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 2", "on conflict do update"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT ON CONSTRAINT t_pkey DO NOTHING", "on constraint"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = EXCLUDED.a", "excluded ref"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 2 WHERE t.b > 0", "do update where"}, + {"INSERT INTO t (a) VALUES (1) RETURNING a", "returning single"}, + {"INSERT INTO t (a) VALUES (1) RETURNING *", "returning star"}, + {"INSERT INTO t (a) VALUES (1) RETURNING a, b", "returning multi"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING RETURNING *", "conflict + returning"}, +}; + +TEST(PgSQLInsertBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_insert_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << 
tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const InsertTestCase mysql_insert_roundtrip_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple"}, + {"INSERT INTO t (a, b) VALUES (1, 'x')", "two cols with string"}, + {"INSERT INTO t (a) VALUES (1), (2), (3)", "multi-row"}, + {"INSERT INTO t SET a = 1, b = 'x'", "set form"}, + {"INSERT LOW_PRIORITY IGNORE INTO t (a) VALUES (1)", "options"}, + {"INSERT INTO t (a) VALUES (1) ON DUPLICATE KEY UPDATE a = 2", "odku"}, + {"REPLACE INTO t (a, b) VALUES (1, 2)", "replace"}, +}; + +TEST(MySQLInsertRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : mysql_insert_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +static const InsertTestCase pgsql_insert_roundtrip_cases[] = { + {"INSERT INTO t (a) VALUES (1)", "simple"}, + {"INSERT INTO t DEFAULT VALUES", "default values"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT DO NOTHING", "on conflict do nothing"}, + {"INSERT INTO t (a) VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 2", "on conflict do update"}, + {"INSERT INTO t (a) VALUES (1) RETURNING *", "returning star"}, +}; + +TEST(PgSQLInsertRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : pgsql_insert_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} 
diff --git a/tests/test_main.cpp b/tests/test_main.cpp new file mode 100644 index 0000000..5ebbc76 --- /dev/null +++ b/tests/test_main.cpp @@ -0,0 +1,6 @@ +#include + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/test_select.cpp b/tests/test_select.cpp new file mode 100644 index 0000000..4df7a51 --- /dev/null +++ b/tests/test_select.cpp @@ -0,0 +1,462 @@ +#include +#include "sql_parser/parser.h" + +using namespace sql_parser; + +class MySQLSelectTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } +}; + +// ========== Basic SELECT ========== + +TEST_F(MySQLSelectTest, SelectLiteral) { + auto r = parser.parse("SELECT 1", 8); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_SELECT_STMT); +} + +TEST_F(MySQLSelectTest, SelectStar) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* items = find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST); + ASSERT_NE(items, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); +} + +TEST_F(MySQLSelectTest, SelectColumns) { + auto r = parser.parse("SELECT id, name, email FROM users", 33); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* items = find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST); + ASSERT_NE(items, nullptr); + EXPECT_EQ(child_count(items), 3); +} + +TEST_F(MySQLSelectTest, SelectWithAlias) { + auto r = parser.parse("SELECT id AS user_id, name 
AS user_name FROM users", 50); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* items = find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST); + ASSERT_NE(items, nullptr); + // Each item should have an alias child + auto* first_item = items->first_child; + ASSERT_NE(first_item, nullptr); + EXPECT_EQ(first_item->type, NodeType::NODE_SELECT_ITEM); + auto* alias = find_child(first_item, NodeType::NODE_ALIAS); + ASSERT_NE(alias, nullptr); +} + +TEST_F(MySQLSelectTest, SelectImplicitAlias) { + // Alias without AS keyword + auto r = parser.parse("SELECT id user_id FROM users", 28); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectDistinct) { + auto r = parser.parse("SELECT DISTINCT name FROM users", 31); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_SELECT_OPTIONS); + ASSERT_NE(opts, nullptr); +} + +TEST_F(MySQLSelectTest, SelectSqlCalcFoundRows) { + auto r = parser.parse("SELECT SQL_CALC_FOUND_ROWS * FROM users", 40); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromQualifiedTable) { + auto r = parser.parse("SELECT * FROM mydb.users", 24); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromTableAlias) { + auto r = parser.parse("SELECT u.id FROM users u", 24); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromTableAsAlias) { + auto r = parser.parse("SELECT u.id FROM users AS u", 27); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectFromMultipleTables) { + auto r = parser.parse("SELECT * FROM users, orders", 27); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); + 
EXPECT_GE(child_count(from), 2); +} + +TEST_F(MySQLSelectTest, SelectExpression) { + auto r = parser.parse("SELECT 1 + 2, 'hello', NOW()", 28); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, SelectNoFrom) { + auto r = parser.parse("SELECT 1, 'a', NOW()", 20); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + // No FROM clause + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + EXPECT_EQ(from, nullptr); +} + +// ========== JOINs ========== + +TEST_F(MySQLSelectTest, InnerJoin) { + auto r = parser.parse("SELECT * FROM users INNER JOIN orders ON users.id = orders.user_id", 66); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); +} + +TEST_F(MySQLSelectTest, LeftJoin) { + auto r = parser.parse("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id", 65); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSelectTest, RightJoin) { + auto r = parser.parse("SELECT * FROM users RIGHT JOIN orders ON users.id = orders.user_id", 66); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, LeftOuterJoin) { + auto r = parser.parse("SELECT * FROM users LEFT OUTER JOIN orders ON users.id = orders.user_id", 71); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, CrossJoin) { + auto r = parser.parse("SELECT * FROM users CROSS JOIN orders", 37); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, NaturalJoin) { + auto r = parser.parse("SELECT * FROM users NATURAL JOIN orders", 39); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, JoinUsing) { + auto r = parser.parse("SELECT * FROM users JOIN orders USING (user_id)", 48); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, MultipleJoins) { + const char* sql = "SELECT * FROM users JOIN orders ON users.id = 
orders.user_id JOIN items ON orders.id = items.order_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, JoinWithAlias) { + auto r = parser.parse("SELECT * FROM users u JOIN orders o ON u.id = o.user_id", 55); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== WHERE ========== + +TEST_F(MySQLSelectTest, WhereSimple) { + auto r = parser.parse("SELECT * FROM users WHERE id = 1", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* where = find_child(r.ast, NodeType::NODE_WHERE_CLAUSE); + ASSERT_NE(where, nullptr); +} + +TEST_F(MySQLSelectTest, WhereComplex) { + auto r = parser.parse("SELECT * FROM users WHERE age > 18 AND status = 'active'", 56); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereIn) { + auto r = parser.parse("SELECT * FROM users WHERE id IN (1, 2, 3)", 42); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereBetween) { + auto r = parser.parse("SELECT * FROM users WHERE age BETWEEN 18 AND 65", 48); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereLike) { + auto r = parser.parse("SELECT * FROM users WHERE name LIKE '%john%'", 44); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereIsNull) { + auto r = parser.parse("SELECT * FROM users WHERE email IS NULL", 39); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, WhereSubquery) { + auto r = parser.parse("SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)", 60); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== GROUP BY / HAVING ========== + +TEST_F(MySQLSelectTest, GroupBy) { + auto r = parser.parse("SELECT status, COUNT(*) FROM users GROUP BY status", 51); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* gb = find_child(r.ast, NodeType::NODE_GROUP_BY_CLAUSE); + ASSERT_NE(gb, nullptr); +} + +TEST_F(MySQLSelectTest, GroupByMultiple) { + auto 
r = parser.parse("SELECT dept, status, COUNT(*) FROM users GROUP BY dept, status", 62); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, GroupByHaving) { + auto r = parser.parse("SELECT status, COUNT(*) FROM users GROUP BY status HAVING COUNT(*) > 5", 71); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* having = find_child(r.ast, NodeType::NODE_HAVING_CLAUSE); + ASSERT_NE(having, nullptr); +} + +// ========== ORDER BY ========== + +TEST_F(MySQLSelectTest, OrderBy) { + auto r = parser.parse("SELECT * FROM users ORDER BY name", 33); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ob = find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE); + ASSERT_NE(ob, nullptr); +} + +TEST_F(MySQLSelectTest, OrderByDesc) { + auto r = parser.parse("SELECT * FROM users ORDER BY created_at DESC", 45); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, OrderByMultiple) { + auto r = parser.parse("SELECT * FROM users ORDER BY last_name ASC, first_name ASC", 58); + EXPECT_EQ(r.status, ParseResult::OK); + auto* ob = find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE); + ASSERT_NE(ob, nullptr); + EXPECT_EQ(child_count(ob), 2); +} + +// ========== LIMIT ========== + +TEST_F(MySQLSelectTest, Limit) { + auto r = parser.parse("SELECT * FROM users LIMIT 10", 28); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* limit = find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE); + ASSERT_NE(limit, nullptr); +} + +TEST_F(MySQLSelectTest, LimitOffset) { + auto r = parser.parse("SELECT * FROM users LIMIT 10 OFFSET 20", 38); + EXPECT_EQ(r.status, ParseResult::OK); + auto* limit = find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE); + ASSERT_NE(limit, nullptr); + EXPECT_EQ(child_count(limit), 2); +} + +TEST_F(MySQLSelectTest, LimitCommaOffset) { + // MySQL syntax: LIMIT offset, count + auto r = parser.parse("SELECT * FROM users LIMIT 20, 10", 32); + EXPECT_EQ(r.status, ParseResult::OK); + auto* limit 
= find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE); + ASSERT_NE(limit, nullptr); + EXPECT_EQ(child_count(limit), 2); +} + +// ========== FOR UPDATE / FOR SHARE ========== + +TEST_F(MySQLSelectTest, ForUpdate) { + auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR UPDATE", 44); + EXPECT_EQ(r.status, ParseResult::OK); + auto* lock = find_child(r.ast, NodeType::NODE_LOCKING_CLAUSE); + ASSERT_NE(lock, nullptr); +} + +TEST_F(MySQLSelectTest, ForShare) { + auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR SHARE", 43); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, ForUpdateNowait) { + auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR UPDATE NOWAIT", 51); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, ForUpdateSkipLocked) { + auto r = parser.parse("SELECT * FROM users WHERE id = 1 FOR UPDATE SKIP LOCKED", 56); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== Complex queries ========== + +TEST_F(MySQLSelectTest, FullQuery) { + const char* sql = "SELECT u.id, u.name, COUNT(o.id) AS order_count " + "FROM users u " + "LEFT JOIN orders o ON u.id = o.user_id " + "WHERE u.status = 'active' " + "GROUP BY u.id, u.name " + "HAVING COUNT(o.id) > 5 " + "ORDER BY order_count DESC " + "LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_SELECT_ITEM_LIST), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_FROM_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_WHERE_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_GROUP_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_HAVING_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLSelectTest, SubqueryInFrom) { + const char* sql = "SELECT t.id FROM (SELECT id FROM users 
WHERE active = 1) AS t"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(MySQLSelectTest, MultiStatement) { + const char* sql = "SELECT 1; SELECT 2"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + EXPECT_TRUE(r.has_remaining()); +} + +TEST_F(MySQLSelectTest, SelectWithSemicolon) { + auto r = parser.parse("SELECT * FROM users;", 20); + EXPECT_EQ(r.status, ParseResult::OK); +} + +// ========== Bulk data-driven tests ========== + +struct SelectTestCase { + const char* sql; + const char* description; +}; + +static const SelectTestCase select_bulk_cases[] = { + {"SELECT 1", "literal"}, + {"SELECT 1, 2, 3", "multiple literals"}, + {"SELECT 'hello'", "string literal"}, + {"SELECT NULL", "null"}, + {"SELECT TRUE", "true"}, + {"SELECT FALSE", "false"}, + {"SELECT NOW()", "function call"}, + {"SELECT 1 + 2", "arithmetic"}, + {"SELECT *", "star"}, + {"SELECT * FROM t", "star from table"}, + {"SELECT a FROM t", "single column"}, + {"SELECT a, b, c FROM t", "multiple columns"}, + {"SELECT a AS x FROM t", "alias with AS"}, + {"SELECT t.a FROM t", "qualified column"}, + {"SELECT t.* FROM t", "qualified star"}, + {"SELECT DISTINCT a FROM t", "distinct"}, + {"SELECT ALL a FROM t", "all"}, + {"SELECT SQL_CALC_FOUND_ROWS * FROM t", "sql_calc_found_rows"}, + {"SELECT * FROM db.t", "qualified table"}, + {"SELECT * FROM t AS alias", "table alias with AS"}, + {"SELECT * FROM t alias", "table alias implicit"}, + {"SELECT * FROM t1, t2", "comma join"}, + {"SELECT * FROM t1 JOIN t2 ON t1.id = t2.id", "inner join"}, + {"SELECT * FROM t1 LEFT JOIN t2 ON t1.id = t2.id", "left join"}, + {"SELECT * FROM t1 RIGHT JOIN t2 ON t1.id = t2.id", "right join"}, + {"SELECT * FROM t1 CROSS JOIN t2", "cross join"}, + {"SELECT * FROM t1 NATURAL JOIN t2", "natural join"}, + {"SELECT * FROM t1 JOIN t2 USING (id)", "join using"}, + {"SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.id = 
t2.id", "left outer join"}, + {"SELECT * FROM t WHERE a = 1", "where equal"}, + {"SELECT * FROM t WHERE a > 1 AND b < 10", "where and"}, + {"SELECT * FROM t WHERE a IN (1,2,3)", "where in"}, + {"SELECT * FROM t WHERE a IS NULL", "where is null"}, + {"SELECT * FROM t WHERE a IS NOT NULL", "where is not null"}, + {"SELECT * FROM t WHERE a BETWEEN 1 AND 10", "where between"}, + {"SELECT * FROM t WHERE a LIKE '%x%'", "where like"}, + {"SELECT * FROM t WHERE a NOT IN (1,2)", "where not in"}, + {"SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t2)", "where exists"}, + {"SELECT a, COUNT(*) FROM t GROUP BY a", "group by"}, + {"SELECT a, b, COUNT(*) FROM t GROUP BY a, b", "group by multiple"}, + {"SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1", "having"}, + {"SELECT * FROM t ORDER BY a", "order by"}, + {"SELECT * FROM t ORDER BY a DESC", "order by desc"}, + {"SELECT * FROM t ORDER BY a ASC, b DESC", "order by multiple"}, + {"SELECT * FROM t LIMIT 10", "limit"}, + {"SELECT * FROM t LIMIT 10 OFFSET 5", "limit offset"}, + {"SELECT * FROM t LIMIT 5, 10", "limit comma"}, + {"SELECT * FROM t WHERE a = 1 FOR UPDATE", "for update"}, + {"SELECT * FROM t WHERE a = 1 FOR SHARE", "for share"}, + {"SELECT * FROM t FOR UPDATE NOWAIT", "for update nowait"}, + {"SELECT * FROM t FOR UPDATE SKIP LOCKED", "for update skip locked"}, + {"SELECT COUNT(*), SUM(a), AVG(b), MIN(c), MAX(d) FROM t", "aggregate functions"}, + {"SELECT CASE WHEN a = 1 THEN 'x' ELSE 'y' END FROM t", "case when"}, + {"SELECT * FROM (SELECT 1) AS t", "subquery in from"}, + {"SELECT * FROM t1 JOIN t2 ON t1.a = t2.a JOIN t3 ON t2.b = t3.b", "multiple joins"}, + {"SELECT a FROM t WHERE b = (SELECT MAX(b) FROM t2)", "scalar subquery in where"}, +}; + +TEST(MySQLSelectBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : select_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + 
EXPECT_EQ(r.stmt_type, StmtType::SELECT) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== PostgreSQL SELECT ========== + +class PgSQLSelectTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(PgSQLSelectTest, BasicSelect) { + auto r = parser.parse("SELECT * FROM users", 19); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSelectTest, LimitOffset) { + auto r = parser.parse("SELECT * FROM users LIMIT 10 OFFSET 5", 37); + EXPECT_EQ(r.status, ParseResult::OK); +} + +TEST_F(PgSQLSelectTest, ForUpdate) { + auto r = parser.parse("SELECT * FROM users FOR UPDATE", 30); + EXPECT_EQ(r.status, ParseResult::OK); +} diff --git a/tests/test_set.cpp b/tests/test_set.cpp new file mode 100644 index 0000000..5c8e423 --- /dev/null +++ b/tests/test_set.cpp @@ -0,0 +1,645 @@ +#include +#include "sql_parser/parser.h" +#include + +using namespace sql_parser; + +// ============================================================================ +// Data-driven test infrastructure +// ============================================================================ + +struct SetTestCase { + const char* sql; + const char* description; +}; + +// ============================================================================ +// MySQL SET bulk test cases — all should parse successfully +// ============================================================================ + +static const SetTestCase mysql_set_cases[] = { + // --- sql_mode variants --- + {"SET @@sql_mode = 'TRADITIONAL'", "sql_mode with @@ prefix"}, + {"SET SESSION sql_mode = 'TRADITIONAL'", "sql_mode with SESSION scope"}, + {"SET @@session.sql_mode = 'TRADITIONAL'", "sql_mode with @@session. prefix"}, + {"SET @@local.sql_mode = 'TRADITIONAL'", "sql_mode with @@local. 
prefix"}, + {"SET sql_mode = 'TRADITIONAL'", "sql_mode unqualified"}, + {"SET SQL_MODE ='TRADITIONAL'", "sql_mode with extra spaces before ="}, + {"SET SQL_MODE = \"TRADITIONAL\"", "sql_mode with double-quoted value"}, + {"SET SQL_MODE = TRADITIONAL", "sql_mode with unquoted value"}, + {"set sql_mode = IFNULL(NULL,\"STRICT_TRANS_TABLES\")", "sql_mode with IFNULL function call"}, + {"set sql_mode = IFNULL(NULL,'STRICT_TRANS_TABLES')", "sql_mode with IFNULL single-quoted"}, + {"SET @@SESSION.sql_mode = CONCAT(CONCAT(@@sql_mode, ', STRICT_ALL_TABLES'), ', NO_AUTO_VALUE_ON_ZERO')", "sql_mode with nested CONCAT"}, + {"SET @@LOCAL.sql_mode = CONCAT(CONCAT(@@sql_mode, ', STRICT_ALL_TABLES'), ', NO_AUTO_VALUE_ON_ZERO')", "sql_mode with nested CONCAT via LOCAL"}, + {"set session sql_mode = 'ONLY_FULL_GROUP_BY'", "sql_mode lowercase session"}, + {"SET sql_mode = 'NO_ZERO_DATE,STRICT_ALL_TABLES,ONLY_FULL_GROUP_BY'", "sql_mode comma-separated modes in string"}, + {"SET @@sql_mode = CONCAT(@@sql_mode, ',', 'ONLY_FULL_GROUP_BY')", "sql_mode CONCAT with 3 args"}, + {"SET @@sql_mode = REPLACE(REPLACE(REPLACE(@@sql_mode, 'ONLY_FULL_GROUP_BY,', ''),',ONLY_FULL_GROUP_BY', ''),'ONLY_FULL_GROUP_BY', '')", "sql_mode deeply nested REPLACE"}, + {"SET @@sql_mode = REPLACE( REPLACE( REPLACE( @@sql_mode, 'ONLY_FULL_GROUP_BY,', ''),',ONLY_FULL_GROUP_BY', ''),'ONLY_FULL_GROUP_BY', '')", "sql_mode deeply nested REPLACE with spaces"}, + {"SET SQL_MODE=IFNULL(@@sql_mode,'')", "sql_mode IFNULL with @@sysvar no spaces"}, + {"SET SQL_MODE=IFNULL(@old_sql_mode,'')", "sql_mode IFNULL with user variable"}, + {"SET SQL_MODE=IFNULL(@OLD_SQL_MODE,'')", "sql_mode IFNULL with uppercase user variable"}, + {"SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION'))", "sql_mode with subquery"}, + {"SET sql_mode=''", "sql_mode empty string"}, + + // --- time_zone variants --- + {"SET @@time_zone = 'Europe/Paris'", "time_zone 2-component"}, + {"SET @@time_zone = '+00:00'", "time_zone 
numeric offset"}, + {"SET @@time_zone = \"Europe/Paris\"", "time_zone double-quoted"}, + {"SET @@time_zone = \"+00:00\"", "time_zone numeric offset double-quoted"}, + {"SET @@TIME_ZONE := 'SYSTEM'", "time_zone colon-equal"}, + {"SET time_zone := 'SYSTEM'", "time_zone unqualified colon-equal"}, + {"SET time_zone = 'UTC'", "time_zone UTC"}, + {"SET time_zone = SYSTEM", "time_zone unquoted SYSTEM"}, + {"SET time_zone = UTC", "time_zone unquoted UTC"}, + {"SET time_zone = 'America/Argentina/Buenos_Aires'", "time_zone 3-component"}, + {"SET time_zone = 'America/Indiana/Indianapolis'", "time_zone 3-component Indianapolis"}, + {"SET time_zone = \"America/Kentucky/Louisville\"", "time_zone 3-component double-quoted"}, + {"SET time_zone = 'America/Port-au-Prince'", "time_zone with hyphens"}, + {"SET time_zone = 'America/Blanc-Sablon'", "time_zone with hyphen"}, + {"SET time_zone = \"US/East-Indiana\"", "time_zone with hyphen double-quoted"}, + {"SET time_zone = '+08:00'", "time_zone +08:00"}, + {"SET time_zone = '-05:30'", "time_zone -05:30"}, + {"SET time_zone = '-10:00'", "time_zone -10:00"}, + {"SET @@time_zone = @OLD_TIME_ZONE", "time_zone user variable RHS"}, + {"SET @@TIME_ZONE = @OLD_TIME_ZONE", "time_zone uppercase user variable RHS"}, + + // --- NAMES / CHARSET --- + {"SET NAMES utf8", "NAMES unquoted"}, + {"SET NAMES 'utf8'", "NAMES single-quoted"}, + {"SET NAMES \"utf8\"", "NAMES double-quoted"}, + {"SET NAMES utf8 COLLATE unicode_ci", "NAMES with COLLATE"}, + {"SET NAMES DEFAULT", "NAMES DEFAULT"}, + + // --- CHARACTER SET / CHARSET --- + {"SET CHARACTER SET utf8", "CHARACTER SET unquoted"}, + {"SET CHARACTER SET 'utf8'", "CHARACTER SET single-quoted"}, + {"SET CHARSET utf8", "CHARSET unquoted"}, + {"SET CHARSET 'latin1'", "CHARSET single-quoted"}, + + // --- Session/Global scope --- + {"SET @@SESSION.SQL_SELECT_LIMIT= DEFAULT", "SESSION SQL_SELECT_LIMIT DEFAULT"}, + {"SET @@LOCAL.SQL_SELECT_LIMIT= DEFAULT", "LOCAL SQL_SELECT_LIMIT DEFAULT"}, + {"SET 
@@SQL_SELECT_LIMIT= DEFAULT", "@@ SQL_SELECT_LIMIT DEFAULT"}, + {"SET SESSION SQL_SELECT_LIMIT = DEFAULT", "SESSION keyword SQL_SELECT_LIMIT DEFAULT"}, + {"SET @@SESSION.SQL_SELECT_LIMIT= 1234", "SESSION SQL_SELECT_LIMIT number"}, + {"SET @@LOCAL.SQL_SELECT_LIMIT= 1234", "LOCAL SQL_SELECT_LIMIT number"}, + {"SET @@SQL_SELECT_LIMIT= 1234", "@@ SQL_SELECT_LIMIT number"}, + {"SET SESSION SQL_SELECT_LIMIT = 1234", "SESSION keyword SQL_SELECT_LIMIT number"}, + {"SET @@SESSION.SQL_SELECT_LIMIT= @old_sql_select_limit", "SESSION SQL_SELECT_LIMIT user var"}, + {"SET @@LOCAL.SQL_SELECT_LIMIT= @old_sql_select_limit", "LOCAL SQL_SELECT_LIMIT user var"}, + {"SET SQL_SELECT_LIMIT= @old_sql_select_limit", "SQL_SELECT_LIMIT user var"}, + {"SET GLOBAL max_connections = 100", "GLOBAL max_connections"}, + {"SET @@SESSION.sql_auto_is_null = 0", "SESSION sql_auto_is_null"}, + {"SET @@LOCAL.sql_auto_is_null = 0", "LOCAL sql_auto_is_null"}, + {"SET SESSION sql_auto_is_null = 1", "SESSION keyword sql_auto_is_null"}, + {"SET sql_auto_is_null = OFF", "sql_auto_is_null OFF"}, + {"SET @@sql_auto_is_null = ON", "sql_auto_is_null ON"}, + {"SET @@SESSION.sql_safe_updates = 0", "SESSION sql_safe_updates"}, + {"SET @@LOCAL.sql_safe_updates = 0", "LOCAL sql_safe_updates"}, + {"SET SESSION sql_safe_updates = 1", "SESSION keyword sql_safe_updates"}, + {"SET SQL_SAFE_UPDATES = OFF", "SQL_SAFE_UPDATES OFF"}, + {"SET @@sql_safe_updates = ON", "sql_safe_updates ON"}, + + // --- session_track_gtids --- + {"SET @@session_track_gtids = OFF", "session_track_gtids OFF"}, + {"SET @@session_track_gtids = OWN_GTID", "session_track_gtids OWN_GTID"}, + {"SET @@SESSION.session_track_gtids = OWN_GTID", "SESSION session_track_gtids OWN_GTID"}, + {"SET @@LOCAL.session_track_gtids = OWN_GTID", "LOCAL session_track_gtids OWN_GTID"}, + {"SET SESSION session_track_gtids = OWN_GTID", "SESSION keyword session_track_gtids OWN_GTID"}, + {"SET @@session_track_gtids = ALL_GTIDS", "session_track_gtids ALL_GTIDS"}, + + // --- 
character_set_results --- + {"SET @@character_set_results = utf8", "character_set_results utf8"}, + {"SET @@character_set_results = NULL", "character_set_results NULL"}, + {"SET character_set_results = NULL", "character_set_results NULL unqualified"}, + {"SET @@session.character_set_results = NULL", "session.character_set_results NULL"}, + {"SET @@local.character_set_results = NULL", "local.character_set_results NULL"}, + {"SET session character_set_results = NULL", "session keyword character_set_results NULL"}, + + // --- Transaction --- + {"SET session transaction read only", "SESSION TRANSACTION READ ONLY"}, + {"SET session transaction read write", "SESSION TRANSACTION READ WRITE"}, + {"SET session transaction isolation level READ COMMITTED", "SESSION TRANSACTION READ COMMITTED"}, + {"SET session transaction isolation level READ UNCOMMITTED", "SESSION TRANSACTION READ UNCOMMITTED"}, + {"SET session transaction isolation level REPEATABLE READ", "SESSION TRANSACTION REPEATABLE READ"}, + {"SET session transaction isolation level SERIALIZABLE", "SESSION TRANSACTION SERIALIZABLE"}, + {"SET TRANSACTION READ ONLY", "TRANSACTION READ ONLY no scope"}, + {"SET TRANSACTION READ WRITE", "TRANSACTION READ WRITE no scope"}, + {"SET TRANSACTION ISOLATION LEVEL REPEATABLE READ", "TRANSACTION ISOLATION LEVEL REPEATABLE READ"}, + {"SET GLOBAL TRANSACTION READ WRITE", "GLOBAL TRANSACTION READ WRITE"}, + + // --- Multiple variables (comma-separated) --- + {"SET time_zone = 'Europe/Paris', sql_mode = 'TRADITIONAL'", "multi: timezone + sql_mode"}, + {"SET time_zone = 'Europe/Paris', sql_mode = IFNULL(NULL,\"STRICT_TRANS_TABLES\")", "multi: timezone + sql_mode IFNULL"}, + {"SET time_zone = 'America/Argentina/Buenos_Aires', sql_mode = 'TRADITIONAL'", "multi: 3-component timezone + sql_mode"}, + {"SET @@SESSION.sql_mode = CONCAT(CONCAT(@@sql_mode, ',STRICT_ALL_TABLES'), ',NO_AUTO_VALUE_ON_ZERO'), @@SESSION.sql_auto_is_null = 0, @@SESSION.wait_timeout = 2147483", "multi: 3 session vars 
with nested CONCAT"}, + {"set autocommit=1, sql_mode = concat(@@sql_mode,',STRICT_TRANS_TABLES')", "multi: autocommit + sql_mode concat"}, + {"SET autocommit = 1, wait_timeout = 28800", "multi: autocommit + wait_timeout"}, + {"SET character_set_connection=utf8,character_set_results=utf8,character_set_client=binary", "multi: 3 charset vars unquoted no spaces"}, + + // --- User variables --- + {"SET @my_var = 42", "user variable numeric"}, + {"SET @old_sql_mode = 'TRADITIONAL'", "user variable string"}, + {"SET @x = 1 + 2", "user variable expression"}, + {"SET @x := 42", "user variable colon-equal"}, + + // --- Optimizer switch --- + {"SET optimizer_switch='index_merge=on,index_merge_union=off'", "optimizer_switch single-quoted"}, + + // --- Multi-statement (semicolon) --- + {"SET autocommit = 0; BEGIN", "multi-statement SET + BEGIN"}, + + // --- Special RHS values --- + {"SET character_set_results = NULL", "NULL RHS"}, + {"SET sql_log_bin=1", "sql_log_bin 1"}, + {"SET sql_log_bin=0", "sql_log_bin 0"}, + {"SET wait_timeout = 28800", "large number RHS"}, + {"SET max_join_size=18446744073709551615", "uint64 max RHS"}, +}; + +// Cases involving NAMES in the middle of multi-SET (comma-separated). +// The current set_parser.h dispatches NAMES only at the top level, +// so these are expected to fail parsing correctly. 
+static const SetTestCase mysql_names_in_multi_set_cases[] = { + {"SET sql_mode = 'TRADITIONAL', NAMES 'utf8' COLLATE 'unicode_ci'", "multi: sql_mode + NAMES COLLATE"}, + {"SET NAMES utf8, @@SESSION.sql_mode = CONCAT(REPLACE(REPLACE(REPLACE(@@sql_mode, 'STRICT_TRANS_TABLES', ''), 'STRICT_ALL_TABLES', ''), 'TRADITIONAL', ''), ',NO_AUTO_VALUE_ON_ZERO'), @@SESSION.sql_auto_is_null = 0, @@SESSION.wait_timeout = 3600", "multi: NAMES + 3 assignments"}, + {"set autocommit=1, session_track_schema=1, sql_mode = concat(@@sql_mode,',STRICT_TRANS_TABLES'), @@SESSION.net_write_timeout=7200", "multi: 4 vars including session"}, + {"SET character_set_results=NULL, NAMES latin7, character_set_client='utf8mb4'", "multi: NULL + NAMES in middle + assignment"}, + {"SET character_set_results=NULL,NAMES latin7,character_set_client='utf8mb4'", "multi: NULL + NAMES in middle no spaces"}, + {"set character_set_results=null, names latin7, character_set_client='utf8mb4'", "multi: lowercase null + names in middle"}, + {"set character_set_results=null,names latin7,character_set_client='utf8mb4'", "multi: lowercase null + names no spaces"}, + {"SET @@autocommit := 0 , NAMES \"utf8mb3\"", "multi: colon-equal + NAMES"}, + {"SET character_set_results=NULL,NAMES latin7,character_set_client='utf8mb4', autocommit := 1 , time_zone = 'Europe/Paris'", "multi: 5 vars with NAMES in middle"}, +}; + +// Cases involving backtick-quoted variable names. +static const SetTestCase mysql_backtick_cases[] = { + {"SET `group_concat_max_len`=4096", "backtick-quoted variable name"}, + {"SET `sql_select_limit`=3030", "backtick-quoted sql_select_limit"}, + {"SET `tx_isolation`='READ-COMMITTED', `group_concat_max_len`=4096", "backtick-quoted multi-var"}, +}; + +// Cases involving backtick-quoted values. +static const SetTestCase mysql_backtick_value_cases[] = { + {"SET optimizer_switch=`index_merge=OFF`", "backtick-quoted value for optimizer_switch"}, +}; + +// Multi-statement with multiple SETs separated by semicolons. 
+static const SetTestCase mysql_multi_statement_cases[] = { + {"SET sql_select_limit=3030, session_track_gtids=OWN_GTID; SET max_join_size=10000;", "multi-statement double SET"}, +}; + +// ============================================================================ +// PostgreSQL SET test cases +// ============================================================================ + +static const SetTestCase pgsql_set_cases[] = { + {"SET client_encoding TO 'UTF8'", "PG client_encoding TO"}, + {"SET work_mem = '256MB'", "PG work_mem"}, + {"SET LOCAL timezone = 'UTC'", "PG LOCAL timezone"}, + {"SET NAMES 'UTF8'", "PG NAMES"}, + {"SET search_path TO public, extensions", "PG search_path TO list"}, +}; + +// ============================================================================ +// MySQL bulk test: all standard cases parse successfully +// ============================================================================ + +TEST(MySQLSetBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_set_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::SET) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ============================================================================ +// MySQL: NAMES in multi-SET (may fail if parser doesn't support it) +// ============================================================================ + +TEST(MySQLSetBulk, NamesInMultiSetCases) { + Parser parser; + int pass_count = 0; + int fail_count = 0; + for (const auto& tc : mysql_names_in_multi_set_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + if (r.status == ParseResult::OK && r.ast != nullptr && r.stmt_type == StmtType::SET) { + pass_count++; + } else { + fail_count++; + // Use non-fatal to report but not block + 
ADD_FAILURE() << "NAMES-in-multi-SET not supported: " << tc.description + << "\n SQL: " << tc.sql; + } + } + std::cout << "[ INFO ] NAMES-in-multi-SET: " << pass_count << " passed, " + << fail_count << " failed (parser gap)" << std::endl; +} + +// ============================================================================ +// MySQL: backtick-quoted variable names +// ============================================================================ + +TEST(MySQLSetBulk, BacktickQuotedVariableNames) { + Parser parser; + for (const auto& tc : mysql_backtick_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::SET) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ============================================================================ +// MySQL: backtick-quoted values +// ============================================================================ + +TEST(MySQLSetBulk, BacktickQuotedValues) { + Parser parser; + for (const auto& tc : mysql_backtick_value_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::SET) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ============================================================================ +// MySQL: multi-statement SETs +// ============================================================================ + +TEST(MySQLSetBulk, MultiStatementSets) { + Parser parser; + for (const auto& tc : mysql_multi_statement_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << 
"\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::SET) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_TRUE(r.has_remaining()) + << "Expected remaining SQL: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ============================================================================ +// PostgreSQL bulk test +// ============================================================================ + +TEST(PgSQLSetBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_set_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::SET) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ============================================================================ +// Individual structural tests — MySQL +// ============================================================================ + +class MySQLSetTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } +}; + +TEST_F(MySQLSetTest, SetSimpleVariable) { + auto r = parser.parse("SET autocommit = 1", 18); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_SET_STMT); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_VAR_ASSIGNMENT); +} + +TEST_F(MySQLSetTest, SetMultipleVariables) { + auto r = parser.parse("SET autocommit = 1, wait_timeout = 28800", 41); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(child_count(r.ast), 2); +} + 
+TEST_F(MySQLSetTest, SetThreeVariables) { + const char* sql = "SET character_set_connection=utf8,character_set_results=utf8,character_set_client=binary"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(child_count(r.ast), 3); +} + +TEST_F(MySQLSetTest, SetGlobalVariable) { + auto r = parser.parse("SET GLOBAL max_connections = 100", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + EXPECT_EQ(target->type, NodeType::NODE_VAR_TARGET); +} + +TEST_F(MySQLSetTest, SetSessionVariable) { + auto r = parser.parse("SET SESSION wait_timeout = 600", 30); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetDoubleAtVariable) { + auto r = parser.parse("SET @@session.wait_timeout = 600", 32); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetUserVariable) { + auto r = parser.parse("SET @my_var = 42", 16); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetNames) { + auto r = parser.parse("SET NAMES utf8mb4", 17); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_SET_STMT); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_NAMES); +} + +TEST_F(MySQLSetTest, SetNamesCollate) { + auto r = parser.parse("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci", 44); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_NAMES); + EXPECT_EQ(child_count(r.ast->first_child), 2); +} + +TEST_F(MySQLSetTest, SetCharacterSet) { + auto r = parser.parse("SET CHARACTER SET utf8", 22); + 
EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_CHARSET); +} + +TEST_F(MySQLSetTest, SetCharset) { + auto r = parser.parse("SET CHARSET utf8", 16); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_CHARSET); +} + +TEST_F(MySQLSetTest, SetTransaction) { + auto r = parser.parse("SET TRANSACTION READ ONLY", 25); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_TRANSACTION); +} + +TEST_F(MySQLSetTest, SetTransactionIsolation) { + auto r = parser.parse("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ", 47); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetGlobalTransaction) { + auto r = parser.parse("SET GLOBAL TRANSACTION READ WRITE", 33); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetSessionTransactionReadOnly) { + const char* sql = "SET session transaction read only"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_TRANSACTION); +} + +TEST_F(MySQLSetTest, SetSessionTransactionReadWrite) { + const char* sql = "SET session transaction read write"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + ASSERT_NE(r.ast->first_child, nullptr); + EXPECT_EQ(r.ast->first_child->type, NodeType::NODE_SET_TRANSACTION); +} + +TEST_F(MySQLSetTest, SetTransactionIsolationReadCommitted) { + const char* sql = "SET session transaction isolation level READ COMMITTED"; + auto r = parser.parse(sql, strlen(sql)); + 
EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetTransactionIsolationReadUncommitted) { + const char* sql = "SET session transaction isolation level READ UNCOMMITTED"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetTransactionIsolationSerializable) { + const char* sql = "SET session transaction isolation level SERIALIZABLE"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetExpressionRHS) { + auto r = parser.parse("SET @x = 1 + 2", 14); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetColonEqual) { + auto r = parser.parse("SET @x := 42", 12); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetNamesDefault) { + auto r = parser.parse("SET NAMES DEFAULT", 17); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetWithSemicolon) { + const char* sql = "SET autocommit = 0; BEGIN"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + EXPECT_TRUE(r.has_remaining()); +} + +TEST_F(MySQLSetTest, SetSubqueryRHS) { + const char* sql = "SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION'))"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + // The assignment should have a subquery node in its RHS + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + EXPECT_EQ(assignment->type, NodeType::NODE_VAR_ASSIGNMENT); +} + +TEST_F(MySQLSetTest, SetFunctionCallRHS) { + const char* sql = "set sql_mode = IFNULL(NULL,'STRICT_TRANS_TABLES')"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + 
ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + // RHS should be a function call node (second child of assignment) + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* rhs = target->next_sibling; + ASSERT_NE(rhs, nullptr); + EXPECT_EQ(rhs->type, NodeType::NODE_FUNCTION_CALL); +} + +TEST_F(MySQLSetTest, SetNestedFunctionCallRHS) { + const char* sql = "SET @@SESSION.sql_mode = CONCAT(CONCAT(@@sql_mode, ', STRICT_ALL_TABLES'), ', NO_AUTO_VALUE_ON_ZERO')"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* rhs = target->next_sibling; + ASSERT_NE(rhs, nullptr); + EXPECT_EQ(rhs->type, NodeType::NODE_FUNCTION_CALL); +} + +TEST_F(MySQLSetTest, SetDeeplyNestedReplace) { + const char* sql = "SET @@sql_mode = REPLACE(REPLACE(REPLACE(@@sql_mode, 'ONLY_FULL_GROUP_BY,', ''),',ONLY_FULL_GROUP_BY', ''),'ONLY_FULL_GROUP_BY', '')"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLSetTest, SetNullRHS) { + const char* sql = "SET character_set_results = NULL"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* rhs = target->next_sibling; + ASSERT_NE(rhs, nullptr); + EXPECT_EQ(rhs->type, NodeType::NODE_LITERAL_NULL); +} + +TEST_F(MySQLSetTest, SetEmptyStringRHS) { + const char* sql = "SET sql_mode=''"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + 
ASSERT_NE(assignment, nullptr); + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* rhs = target->next_sibling; + ASSERT_NE(rhs, nullptr); + EXPECT_EQ(rhs->type, NodeType::NODE_LITERAL_STRING); +} + +TEST_F(MySQLSetTest, SetDefaultRHS) { + const char* sql = "SET @@SESSION.SQL_SELECT_LIMIT= DEFAULT"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* rhs = target->next_sibling; + ASSERT_NE(rhs, nullptr); + EXPECT_EQ(rhs->type, NodeType::NODE_IDENTIFIER); +} + +TEST_F(MySQLSetTest, SetUserVariableRHS) { + const char* sql = "SET @@time_zone = @OLD_TIME_ZONE"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + AstNode* assignment = r.ast->first_child; + ASSERT_NE(assignment, nullptr); + AstNode* target = assignment->first_child; + ASSERT_NE(target, nullptr); + AstNode* rhs = target->next_sibling; + ASSERT_NE(rhs, nullptr); + EXPECT_EQ(rhs->type, NodeType::NODE_COLUMN_REF); +} + +// ============================================================================ +// Individual structural tests — PostgreSQL +// ============================================================================ + +class PgSQLSetTest : public ::testing::Test { +protected: + Parser parser; +}; + +TEST_F(PgSQLSetTest, SetVarToValue) { + auto r = parser.parse("SET client_encoding TO 'UTF8'", 29); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetVarEqualValue) { + auto r = parser.parse("SET work_mem = '256MB'", 22); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetLocalVar) { + auto r = parser.parse("SET LOCAL timezone = 'UTC'", 26); + 
EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetNamesPostgres) { + auto r = parser.parse("SET NAMES 'UTF8'", 16); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLSetTest, SetSearchPathToList) { + const char* sql = "SET search_path TO public, extensions"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} diff --git a/tests/test_stmt_cache.cpp b/tests/test_stmt_cache.cpp new file mode 100644 index 0000000..83ee569 --- /dev/null +++ b/tests/test_stmt_cache.cpp @@ -0,0 +1,265 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/stmt_cache.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +// ========== StmtCache unit tests ========== + +TEST(StmtCacheTest, StoreAndLookup) { + StmtCache cache(16); + Arena arena(4096); + + AstNode* node = make_node(arena, NodeType::NODE_SET_STMT, StringRef{"SET", 3}); + ASSERT_NE(node, nullptr); + + EXPECT_TRUE(cache.store(1, StmtType::SET, node)); + EXPECT_EQ(cache.size(), 1u); + + const CachedStmt* found = cache.lookup(1); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->stmt_id, 1u); + EXPECT_EQ(found->stmt_type, StmtType::SET); + ASSERT_NE(found->ast, nullptr); + EXPECT_EQ(found->ast->type, NodeType::NODE_SET_STMT); +} + +TEST(StmtCacheTest, LookupMiss) { + StmtCache cache(16); + EXPECT_EQ(cache.lookup(999), nullptr); +} + +TEST(StmtCacheTest, Evict) { + StmtCache cache(16); + Arena arena(4096); + + AstNode* node = make_node(arena, NodeType::NODE_SELECT_STMT); + cache.store(1, StmtType::SELECT, node); + EXPECT_EQ(cache.size(), 1u); + + cache.evict(1); + EXPECT_EQ(cache.size(), 0u); + EXPECT_EQ(cache.lookup(1), nullptr); +} + +TEST(StmtCacheTest, LRUEviction) { + StmtCache cache(2); // capacity = 2 + Arena arena(4096); + + AstNode* n1 = make_node(arena, NodeType::NODE_SET_STMT); + AstNode* n2 = make_node(arena, NodeType::NODE_SELECT_STMT); + AstNode* n3 
= make_node(arena, NodeType::NODE_SET_STMT); + + cache.store(1, StmtType::SET, n1); + cache.store(2, StmtType::SELECT, n2); + EXPECT_EQ(cache.size(), 2u); + + // Adding a third should evict the LRU (stmt 1) + cache.store(3, StmtType::SET, n3); + EXPECT_EQ(cache.size(), 2u); + EXPECT_EQ(cache.lookup(1), nullptr); // evicted + EXPECT_NE(cache.lookup(2), nullptr); // still there + EXPECT_NE(cache.lookup(3), nullptr); // just added +} + +TEST(StmtCacheTest, LRUTouchOnLookup) { + StmtCache cache(2); + Arena arena(4096); + + AstNode* n1 = make_node(arena, NodeType::NODE_SET_STMT); + AstNode* n2 = make_node(arena, NodeType::NODE_SELECT_STMT); + AstNode* n3 = make_node(arena, NodeType::NODE_SET_STMT); + + cache.store(1, StmtType::SET, n1); + cache.store(2, StmtType::SELECT, n2); + + // Touch stmt 1 to make it recently used + cache.lookup(1); + + // Adding stmt 3 should evict stmt 2 (now the LRU) + cache.store(3, StmtType::SET, n3); + EXPECT_NE(cache.lookup(1), nullptr); // touched, still alive + EXPECT_EQ(cache.lookup(2), nullptr); // evicted + EXPECT_NE(cache.lookup(3), nullptr); +} + +TEST(StmtCacheTest, DeepCopyPreservesTree) { + Arena arena(4096); + + // Build a small tree: SET_STMT -> VAR_ASSIGNMENT -> (VAR_TARGET, LITERAL_INT) + AstNode* root = make_node(arena, NodeType::NODE_SET_STMT); + AstNode* assign = make_node(arena, NodeType::NODE_VAR_ASSIGNMENT); + AstNode* target = make_node(arena, NodeType::NODE_VAR_TARGET); + target->add_child(make_node(arena, NodeType::NODE_IDENTIFIER, StringRef{"autocommit", 10})); + AstNode* value = make_node(arena, NodeType::NODE_LITERAL_INT, StringRef{"1", 1}); + assign->add_child(target); + assign->add_child(value); + root->add_child(assign); + + // Deep copy + AstNode* copy = deep_copy_ast(root); + ASSERT_NE(copy, nullptr); + EXPECT_EQ(copy->type, NodeType::NODE_SET_STMT); + + // Verify tree structure is preserved + ASSERT_NE(copy->first_child, nullptr); + EXPECT_EQ(copy->first_child->type, NodeType::NODE_VAR_ASSIGNMENT); + + 
AstNode* copy_target = copy->first_child->first_child; + ASSERT_NE(copy_target, nullptr); + EXPECT_EQ(copy_target->type, NodeType::NODE_VAR_TARGET); + + AstNode* copy_name = copy_target->first_child; + ASSERT_NE(copy_name, nullptr); + EXPECT_EQ(std::string(copy_name->value_ptr, copy_name->value_len), "autocommit"); + + // Verify it's a deep copy (different pointers) + EXPECT_NE(copy, root); + EXPECT_NE(copy->first_child, root->first_child); + EXPECT_NE(copy_name->value_ptr, target->first_child->value_ptr); + + // Reset arena -- copy should still be valid + arena.reset(); + EXPECT_EQ(std::string(copy_name->value_ptr, copy_name->value_len), "autocommit"); + + free_ast(copy); +} + +// ========== Parser integration tests ========== + +TEST(PreparedStmtTest, ParseAndCache) { + Parser parser; + + auto r = parser.parse_and_cache("SELECT * FROM users WHERE id = ?", 32, 1); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SELECT); + ASSERT_NE(r.ast, nullptr); +} + +TEST(PreparedStmtTest, ExecuteAfterCache) { + Parser parser; + + parser.parse_and_cache("SET autocommit = ?", 18, 42); + + // Build bindings + BoundValue bv; + bv.type = BoundValue::INT; + bv.int_val = 0; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(42, bindings); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::SET); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.bindings.count, 1); + EXPECT_EQ(r.bindings.values[0].int_val, 0); +} + +TEST(PreparedStmtTest, ExecuteNotFound) { + Parser parser; + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(999, bindings); + EXPECT_EQ(r.status, ParseResult::ERROR); +} + +TEST(PreparedStmtTest, EvictAndExecuteFails) { + Parser parser; + + parser.parse_and_cache("SELECT 1", 8, 10); + parser.prepare_cache_evict(10); + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(10, bindings); + 
EXPECT_EQ(r.status, ParseResult::ERROR); +} + +TEST(PreparedStmtTest, CacheMultipleStatements) { + Parser parser; + + parser.parse_and_cache("SELECT 1", 8, 1); + parser.parse_and_cache("SELECT 2", 8, 2); + parser.parse_and_cache("SET autocommit = 0", 18, 3); + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings bindings{&bv, 0}; + + auto r1 = parser.execute(1, bindings); + EXPECT_EQ(r1.status, ParseResult::OK); + EXPECT_EQ(r1.stmt_type, StmtType::SELECT); + + auto r3 = parser.execute(3, bindings); + EXPECT_EQ(r3.status, ParseResult::OK); + EXPECT_EQ(r3.stmt_type, StmtType::SET); +} + +// ========== Emitter with bindings ========== + +TEST(PreparedStmtTest, EmitWithBindings) { + Parser parser; + + parser.parse_and_cache("SET autocommit = ?", 18, 1); + + BoundValue bv; + bv.type = BoundValue::INT; + bv.int_val = 1; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(1, bindings); + ASSERT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + + Emitter emitter(parser.arena(), EmitMode::NORMAL, &r.bindings); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET autocommit = 1"); +} + +TEST(PreparedStmtTest, EmitWithStringBinding) { + Parser parser; + + parser.parse_and_cache("SET sql_mode = ?", 16, 2); + + const char* mode = "TRADITIONAL"; + BoundValue bv; + bv.type = BoundValue::STRING; + bv.str_val = StringRef{mode, 11}; + ParamBindings bindings{&bv, 1}; + + auto r = parser.execute(2, bindings); + ASSERT_EQ(r.status, ParseResult::OK); + + Emitter emitter(parser.arena(), EmitMode::NORMAL, &r.bindings); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET sql_mode = 'TRADITIONAL'"); +} + +TEST(PreparedStmtTest, EmitWithNullBinding) { + Parser parser; + + parser.parse_and_cache("SET character_set_results = ?", 29, 3); + + BoundValue bv; + bv.type = BoundValue::NULL_VAL; + ParamBindings 
bindings{&bv, 1}; + + auto r = parser.execute(3, bindings); + ASSERT_EQ(r.status, ParseResult::OK); + + Emitter emitter(parser.arena(), EmitMode::NORMAL, &r.bindings); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, "SET character_set_results = NULL"); +} diff --git a/tests/test_tokenizer.cpp b/tests/test_tokenizer.cpp new file mode 100644 index 0000000..0ea194e --- /dev/null +++ b/tests/test_tokenizer.cpp @@ -0,0 +1,243 @@ +#include +#include "sql_parser/tokenizer.h" + +using namespace sql_parser; + +// ========== MySQL Tokenizer Tests ========== + +class MySQLTokenizerTest : public ::testing::Test { +protected: + Tokenizer tok; +}; + +TEST_F(MySQLTokenizerTest, SimpleSelect) { + const char* sql = "SELECT * FROM users;"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_SELECT); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_ASTERISK); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_FROM); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + EXPECT_EQ(t.text.len, 5u); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "users"); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_SEMICOLON); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_EOF); +} + +TEST_F(MySQLTokenizerTest, CaseInsensitiveKeywords) { + const char* sql = "select FROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, BacktickIdentifier) { + const char* sql = "`my table`"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "my table"); +} + +TEST_F(MySQLTokenizerTest, SingleQuotedString) { + const char* sql = "'hello world'"; + tok.reset(sql, strlen(sql)); + + Token t = 
tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_STRING); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "hello world"); +} + +TEST_F(MySQLTokenizerTest, IntegerLiteral) { + const char* sql = "42"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_INTEGER); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "42"); +} + +TEST_F(MySQLTokenizerTest, FloatLiteral) { + const char* sql = "3.14"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_FLOAT); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "3.14"); +} + +TEST_F(MySQLTokenizerTest, ComparisonOperators) { + const char* sql = "= != < > <= >="; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_EQUAL); + EXPECT_EQ(tok.next_token().type, TokenType::TK_NOT_EQUAL); + EXPECT_EQ(tok.next_token().type, TokenType::TK_LESS); + EXPECT_EQ(tok.next_token().type, TokenType::TK_GREATER); + EXPECT_EQ(tok.next_token().type, TokenType::TK_LESS_EQUAL); + EXPECT_EQ(tok.next_token().type, TokenType::TK_GREATER_EQUAL); +} + +TEST_F(MySQLTokenizerTest, DiamondNotEqual) { + const char* sql = "<>"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_NOT_EQUAL); +} + +TEST_F(MySQLTokenizerTest, AtVariables) { + const char* sql = "@myvar @@global_var"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_AT); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_DOUBLE_AT); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); +} + +TEST_F(MySQLTokenizerTest, Placeholder) { + const char* sql = "?"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_QUESTION); +} + +TEST_F(MySQLTokenizerTest, ColonEqual) { + const char* sql = ":="; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, 
TokenType::TK_COLON_EQUAL); +} + +TEST_F(MySQLTokenizerTest, LineComment) { + const char* sql = "SELECT -- this is a comment\nFROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, HashComment) { + const char* sql = "SELECT # comment\nFROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, BlockComment) { + const char* sql = "SELECT /* comment */ FROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, PeekDoesNotConsume) { + const char* sql = "SELECT FROM"; + tok.reset(sql, strlen(sql)); + + Token peeked = tok.peek(); + EXPECT_EQ(peeked.type, TokenType::TK_SELECT); + + Token consumed = tok.next_token(); + EXPECT_EQ(consumed.type, TokenType::TK_SELECT); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(MySQLTokenizerTest, EmptyInput) { + tok.reset("", 0); + EXPECT_EQ(tok.next_token().type, TokenType::TK_EOF); +} + +TEST_F(MySQLTokenizerTest, WhitespaceOnly) { + const char* sql = " \t\n\r "; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_EOF); +} + +TEST_F(MySQLTokenizerTest, QualifiedIdentifier) { + const char* sql = "myschema.orders"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_IDENTIFIER); // myschema + EXPECT_EQ(tok.next_token().type, TokenType::TK_DOT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_IDENTIFIER); // orders +} + +// ========== PostgreSQL Tokenizer Tests ========== + +class PgSQLTokenizerTest : public ::testing::Test { +protected: + Tokenizer tok; +}; + +TEST_F(PgSQLTokenizerTest, DoubleQuotedIdentifier) { + const char* sql = "\"my table\""; + tok.reset(sql, 
strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_IDENTIFIER); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "my table"); +} + +TEST_F(PgSQLTokenizerTest, DollarQuotedString) { + const char* sql = "$$hello world$$"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_STRING); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "hello world"); +} + +TEST_F(PgSQLTokenizerTest, DoubleColonCast) { + const char* sql = "::"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_DOUBLE_COLON); +} + +TEST_F(PgSQLTokenizerTest, PositionalParam) { + const char* sql = "$1 $23"; + tok.reset(sql, strlen(sql)); + + Token t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_DOLLAR_NUM); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "$1"); + + t = tok.next_token(); + EXPECT_EQ(t.type, TokenType::TK_DOLLAR_NUM); + EXPECT_EQ(std::string(t.text.ptr, t.text.len), "$23"); +} + +TEST_F(PgSQLTokenizerTest, NestedBlockComment) { + const char* sql = "SELECT /* outer /* inner */ still comment */ FROM"; + tok.reset(sql, strlen(sql)); + + EXPECT_EQ(tok.next_token().type, TokenType::TK_SELECT); + EXPECT_EQ(tok.next_token().type, TokenType::TK_FROM); +} + +TEST_F(PgSQLTokenizerTest, NoHashComment) { + // PostgreSQL does NOT support # comments — # should be TK_HASH token + const char* sql = "#"; + tok.reset(sql, strlen(sql)); + EXPECT_EQ(tok.next_token().type, TokenType::TK_HASH); +} diff --git a/tests/test_update.cpp b/tests/test_update.cpp new file mode 100644 index 0000000..9b71dc9 --- /dev/null +++ b/tests/test_update.cpp @@ -0,0 +1,330 @@ +#include +#include "sql_parser/parser.h" +#include "sql_parser/emitter.h" + +using namespace sql_parser; + +class MySQLUpdateTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* 
find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +// ========== Basic UPDATE ========== + +TEST_F(MySQLUpdateTest, SimpleUpdate) { + auto r = parser.parse("UPDATE users SET name = 'Alice' WHERE id = 1", 45); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::UPDATE); + ASSERT_NE(r.ast, nullptr); + EXPECT_EQ(r.ast->type, NodeType::NODE_UPDATE_STMT); +} + +TEST_F(MySQLUpdateTest, UpdateMultipleColumns) { + const char* sql = "UPDATE users SET name = 'Alice', email = 'a@b.com' WHERE id = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* set_clause = find_child(r.ast, NodeType::NODE_UPDATE_SET_CLAUSE); + ASSERT_NE(set_clause, nullptr); + EXPECT_EQ(child_count(set_clause), 2); +} + +TEST_F(MySQLUpdateTest, UpdateNoWhere) { + auto r = parser.parse("UPDATE users SET active = 0", 27); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* where = find_child(r.ast, NodeType::NODE_WHERE_CLAUSE); + EXPECT_EQ(where, nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateQualifiedTable) { + auto r = parser.parse("UPDATE mydb.users SET name = 'x' WHERE id = 1", 46); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL Options ========== + +TEST_F(MySQLUpdateTest, UpdateLowPriority) { + auto r = parser.parse("UPDATE LOW_PRIORITY users SET name = 'x' WHERE id = 1", 54); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* opts = find_child(r.ast, NodeType::NODE_STMT_OPTIONS); + ASSERT_NE(opts, nullptr); 
+} + +TEST_F(MySQLUpdateTest, UpdateIgnore) { + auto r = parser.parse("UPDATE IGNORE users SET name = 'x' WHERE id = 1", 48); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateLowPriorityIgnore) { + auto r = parser.parse("UPDATE LOW_PRIORITY IGNORE users SET name = 'x' WHERE id = 1", 61); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== MySQL ORDER BY + LIMIT ========== + +TEST_F(MySQLUpdateTest, UpdateOrderByLimit) { + const char* sql = "UPDATE users SET rank = rank + 1 WHERE active = 1 ORDER BY score DESC LIMIT 10"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_ORDER_BY_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +TEST_F(MySQLUpdateTest, UpdateLimit) { + auto r = parser.parse("UPDATE users SET active = 0 LIMIT 100", 37); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_LIMIT_CLAUSE), nullptr); +} + +// ========== MySQL Multi-Table UPDATE ========== + +TEST_F(MySQLUpdateTest, MultiTableJoin) { + const char* sql = "UPDATE users u JOIN orders o ON u.id = o.user_id " + "SET u.total = u.total + o.amount WHERE o.status = 'shipped'"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLUpdateTest, MultiTableCommaJoin) { + const char* sql = "UPDATE users, orders SET users.total = orders.amount " + "WHERE users.id = orders.user_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(MySQLUpdateTest, MultiTableLeftJoin) { + const char* sql = "UPDATE users u LEFT JOIN orders o ON u.id = o.user_id " + "SET u.has_orders = 0 WHERE o.id IS NULL"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, 
ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== PostgreSQL UPDATE ========== + +class PgSQLUpdateTest : public ::testing::Test { +protected: + Parser parser; + + int child_count(const AstNode* node) { + int n = 0; + for (const AstNode* c = node->first_child; c; c = c->next_sibling) ++n; + return n; + } + + const AstNode* find_child(const AstNode* node, NodeType type) { + for (const AstNode* c = node->first_child; c; c = c->next_sibling) { + if (c->type == type) return c; + } + return nullptr; + } + + std::string round_trip(const char* sql) { + auto r = parser.parse(sql, strlen(sql)); + if (!r.ast) return "[PARSE_FAILED]"; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + return std::string(result.ptr, result.len); + } +}; + +TEST_F(PgSQLUpdateTest, SimpleUpdate) { + auto r = parser.parse("UPDATE users SET name = 'Alice' WHERE id = 1", 45); + EXPECT_EQ(r.status, ParseResult::OK); + EXPECT_EQ(r.stmt_type, StmtType::UPDATE); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateFrom) { + const char* sql = "UPDATE users SET total = orders.amount FROM orders WHERE users.id = orders.user_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* from = find_child(r.ast, NodeType::NODE_FROM_CLAUSE); + ASSERT_NE(from, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateFromMultipleTables) { + const char* sql = "UPDATE users SET total = o.amount " + "FROM orders o, payments p " + "WHERE users.id = o.user_id AND o.id = p.order_id"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateReturning) { + const char* sql = "UPDATE users SET name = 'Alice' WHERE id = 1 RETURNING id, name"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + auto* ret = find_child(r.ast, 
NodeType::NODE_RETURNING_CLAUSE); + ASSERT_NE(ret, nullptr); + EXPECT_EQ(child_count(ret), 2); +} + +TEST_F(PgSQLUpdateTest, UpdateReturningStar) { + const char* sql = "UPDATE users SET name = 'Alice' WHERE id = 1 RETURNING *"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateFromReturning) { + const char* sql = "UPDATE users SET total = o.amount FROM orders o " + "WHERE users.id = o.user_id RETURNING users.id, users.total"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_FROM_CLAUSE), nullptr); + EXPECT_NE(find_child(r.ast, NodeType::NODE_RETURNING_CLAUSE), nullptr); +} + +TEST_F(PgSQLUpdateTest, UpdateWithAlias) { + const char* sql = "UPDATE users AS u SET name = 'Alice' WHERE u.id = 1"; + auto r = parser.parse(sql, strlen(sql)); + EXPECT_EQ(r.status, ParseResult::OK); + ASSERT_NE(r.ast, nullptr); +} + +// ========== Bulk data-driven tests ========== + +struct UpdateTestCase { + const char* sql; + const char* description; +}; + +static const UpdateTestCase mysql_update_bulk_cases[] = { + {"UPDATE t SET a = 1", "simple no where"}, + {"UPDATE t SET a = 1 WHERE b = 2", "simple with where"}, + {"UPDATE t SET a = 1, b = 2 WHERE c = 3", "multi column"}, + {"UPDATE t SET a = a + 1 WHERE b > 0", "expression value"}, + {"UPDATE t SET a = 'hello' WHERE b = 1", "string value"}, + {"UPDATE db.t SET a = 1", "qualified table"}, + {"UPDATE LOW_PRIORITY t SET a = 1", "low priority"}, + {"UPDATE IGNORE t SET a = 1", "ignore"}, + {"UPDATE LOW_PRIORITY IGNORE t SET a = 1", "low priority ignore"}, + {"UPDATE t SET a = 1 ORDER BY b LIMIT 10", "order by limit"}, + {"UPDATE t SET a = 1 LIMIT 100", "limit only"}, + {"UPDATE t1 JOIN t2 ON t1.id = t2.fk SET t1.a = t2.b", "join update"}, + {"UPDATE t1, t2 SET t1.a = t2.b WHERE t1.id = t2.fk", "comma join update"}, + {"UPDATE t1 LEFT 
JOIN t2 ON t1.id = t2.fk SET t1.a = 0 WHERE t2.id IS NULL", "left join"}, + {"UPDATE t SET a = NOW()", "function in value"}, + {"UPDATE t SET a = NULL WHERE b = 1", "set null"}, + {"UPDATE t SET a = CASE WHEN b > 0 THEN 1 ELSE 0 END", "case expression"}, +}; + +TEST(MySQLUpdateBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : mysql_update_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::UPDATE) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +static const UpdateTestCase pgsql_update_bulk_cases[] = { + {"UPDATE t SET a = 1", "simple no where"}, + {"UPDATE t SET a = 1 WHERE b = 2", "simple with where"}, + {"UPDATE t SET a = 1, b = 2 WHERE c = 3", "multi column"}, + {"UPDATE t AS x SET a = 1 WHERE x.b = 2", "alias"}, + {"UPDATE t SET a = 1 FROM t2 WHERE t.id = t2.fk", "from clause"}, + {"UPDATE t SET a = t2.b FROM t2, t3 WHERE t.id = t2.fk AND t2.id = t3.fk", "from multi"}, + {"UPDATE t SET a = 1 WHERE b = 2 RETURNING *", "returning star"}, + {"UPDATE t SET a = 1 WHERE b = 2 RETURNING a, b", "returning cols"}, + {"UPDATE t SET a = 1 FROM t2 WHERE t.id = t2.fk RETURNING t.a", "from + returning"}, +}; + +TEST(PgSQLUpdateBulk, AllCasesParseSuccessfully) { + Parser parser; + for (const auto& tc : pgsql_update_bulk_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + EXPECT_EQ(r.status, ParseResult::OK) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + EXPECT_EQ(r.stmt_type, StmtType::UPDATE) + << "Failed: " << tc.description; + EXPECT_NE(r.ast, nullptr) + << "Failed: " << tc.description << "\n SQL: " << tc.sql; + } +} + +// ========== Round-trip tests ========== + +static const UpdateTestCase mysql_update_roundtrip_cases[] = { + {"UPDATE t SET a = 1 WHERE b = 2", "simple"}, + {"UPDATE t SET a = 1, b = 
'x' WHERE c = 3", "multi col"}, + {"UPDATE LOW_PRIORITY IGNORE t SET a = 1", "options"}, + {"UPDATE t SET a = 1 ORDER BY b DESC LIMIT 10", "order by limit"}, +}; + +TEST(MySQLUpdateRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : mysql_update_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} + +static const UpdateTestCase pgsql_update_roundtrip_cases[] = { + {"UPDATE t SET a = 1 WHERE b = 2", "simple"}, + {"UPDATE t SET a = 1 FROM t2 WHERE t.id = t2.fk", "from clause"}, + {"UPDATE t SET a = 1 WHERE b = 2 RETURNING *", "returning"}, +}; + +TEST(PgSQLUpdateRoundTrip, AllCasesRoundTrip) { + Parser parser; + for (const auto& tc : pgsql_update_roundtrip_cases) { + auto r = parser.parse(tc.sql, strlen(tc.sql)); + ASSERT_NE(r.ast, nullptr) + << "Parse failed: " << tc.description << "\n SQL: " << tc.sql; + Emitter emitter(parser.arena()); + emitter.emit(r.ast); + StringRef result = emitter.result(); + std::string out(result.ptr, result.len); + EXPECT_EQ(out, std::string(tc.sql)) + << "Round-trip mismatch: " << tc.description; + } +} diff --git a/third_party/benchmark/.clang-format b/third_party/benchmark/.clang-format new file mode 100644 index 0000000..e7d00fe --- /dev/null +++ b/third_party/benchmark/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +PointerAlignment: Left +... 
diff --git a/third_party/benchmark/.clang-tidy b/third_party/benchmark/.clang-tidy new file mode 100644 index 0000000..1e229e5 --- /dev/null +++ b/third_party/benchmark/.clang-tidy @@ -0,0 +1,6 @@ +--- +Checks: 'clang-analyzer-*,readability-redundant-*,performance-*' +WarningsAsErrors: 'clang-analyzer-*,readability-redundant-*,performance-*' +HeaderFilterRegex: '.*' +FormatStyle: none +User: user diff --git a/third_party/benchmark/.github/ISSUE_TEMPLATE/bug_report.md b/third_party/benchmark/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..6c2ced9 --- /dev/null +++ b/third_party/benchmark/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,32 @@ +--- +name: Bug report +about: Create a report to help us improve +title: "[BUG]" +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**System** +Which OS, compiler, and compiler version are you using: + - OS: + - Compiler and version: + +**To reproduce** +Steps to reproduce the behavior: +1. sync to commit ... +2. cmake/bazel... +3. make ... +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Additional context** +Add any other context about the problem here. diff --git a/third_party/benchmark/.github/ISSUE_TEMPLATE/feature_request.md b/third_party/benchmark/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..9e8ab6a --- /dev/null +++ b/third_party/benchmark/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "[FR]" +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 
+ +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/third_party/benchmark/.github/install_bazel.sh b/third_party/benchmark/.github/install_bazel.sh new file mode 100644 index 0000000..1b0d63c --- /dev/null +++ b/third_party/benchmark/.github/install_bazel.sh @@ -0,0 +1,12 @@ +if ! bazel version; then + arch=$(uname -m) + if [ "$arch" == "aarch64" ]; then + arch="arm64" + fi + echo "Downloading $arch Bazel binary from GitHub releases." + curl -L -o $HOME/bin/bazel --create-dirs "https://github.com/bazelbuild/bazel/releases/download/7.1.1/bazel-7.1.1-linux-$arch" + chmod +x $HOME/bin/bazel +else + # Bazel is installed for the correct architecture + exit 0 +fi diff --git a/third_party/benchmark/.github/libcxx-setup.sh b/third_party/benchmark/.github/libcxx-setup.sh new file mode 100755 index 0000000..9aaf96a --- /dev/null +++ b/third_party/benchmark/.github/libcxx-setup.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -e + +# Checkout LLVM sources +git clone --depth=1 --branch llvmorg-16.0.6 https://github.com/llvm/llvm-project.git llvm-project + +## Setup libc++ options +if [ -z "$BUILD_32_BITS" ]; then + export BUILD_32_BITS=OFF && echo disabling 32 bit build +fi + +## Build and install libc++ (Use unstable ABI for better sanitizer coverage) +mkdir llvm-build && cd llvm-build +cmake -DCMAKE_C_COMPILER=${CC} \ + -DCMAKE_CXX_COMPILER=${CXX} \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_INSTALL_PREFIX=/usr \ + -DLIBCXX_ABI_UNSTABLE=OFF \ + -DLLVM_USE_SANITIZER=${LIBCXX_SANITIZER} \ + -DLLVM_BUILD_32_BITS=${BUILD_32_BITS} \ + -DLLVM_ENABLE_RUNTIMES='libcxx;libcxxabi;libunwind' \ + -G "Unix Makefiles" \ + ../llvm-project/runtimes/ +make -j cxx cxxabi unwind +cd .. 
diff --git a/third_party/benchmark/.github/workflows/bazel.yml b/third_party/benchmark/.github/workflows/bazel.yml new file mode 100644 index 0000000..b50a8f6 --- /dev/null +++ b/third_party/benchmark/.github/workflows/bazel.yml @@ -0,0 +1,35 @@ +name: bazel + +on: + push: {} + pull_request: {} + +jobs: + build_and_test_default: + name: bazel.${{ matrix.os }}.${{ matrix.bzlmod && 'bzlmod' || 'no_bzlmod' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + bzlmod: [false, true] + steps: + - uses: actions/checkout@v4 + + - name: mount bazel cache + uses: actions/cache@v4 + env: + cache-name: bazel-cache + with: + path: "~/.cache/bazel" + key: ${{ env.cache-name }}-${{ matrix.os }}-${{ github.ref }} + restore-keys: | + ${{ env.cache-name }}-${{ matrix.os }}-main + + - name: build + run: | + bazel build ${{ matrix.bzlmod && '--enable_bzlmod' || '--noenable_bzlmod' }} //:benchmark //:benchmark_main //test/... + + - name: test + run: | + bazel test ${{ matrix.bzlmod && '--enable_bzlmod' || '--noenable_bzlmod' }} --test_output=all //test/... 
diff --git a/third_party/benchmark/.github/workflows/build-and-test-min-cmake.yml b/third_party/benchmark/.github/workflows/build-and-test-min-cmake.yml new file mode 100644 index 0000000..2509984 --- /dev/null +++ b/third_party/benchmark/.github/workflows/build-and-test-min-cmake.yml @@ -0,0 +1,46 @@ +name: build-and-test-min-cmake + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + job: + name: ${{ matrix.os }}.min-cmake + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v4 + + - uses: lukka/get-cmake@latest + with: + cmakeVersion: 3.13.0 + + - name: create build environment + run: cmake -E make_directory ${{ runner.workspace }}/_build + + - name: setup cmake initial cache + run: touch compiler-cache.cmake + + - name: configure cmake + env: + CXX: ${{ matrix.compiler }} + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: > + cmake -C ${{ github.workspace }}/compiler-cache.cmake + $GITHUB_WORKSPACE + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_VISIBILITY_INLINES_HIDDEN=ON + + - name: build + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: cmake --build . diff --git a/third_party/benchmark/.github/workflows/build-and-test-perfcounters.yml b/third_party/benchmark/.github/workflows/build-and-test-perfcounters.yml new file mode 100644 index 0000000..319d42d --- /dev/null +++ b/third_party/benchmark/.github/workflows/build-and-test-perfcounters.yml @@ -0,0 +1,51 @@ +name: build-and-test-perfcounters + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + job: + # TODO(dominic): Extend this to include compiler and set through env: CC/CXX. 
+ name: ${{ matrix.os }}.${{ matrix.build_type }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-20.04] + build_type: ['Release', 'Debug'] + steps: + - uses: actions/checkout@v4 + + - name: install libpfm + run: | + sudo apt update + sudo apt -y install libpfm4-dev + + - name: create build environment + run: cmake -E make_directory ${{ runner.workspace }}/_build + + - name: configure cmake + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: > + cmake $GITHUB_WORKSPACE + -DBENCHMARK_ENABLE_LIBPFM=1 + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + + - name: build + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: cmake --build . --config ${{ matrix.build_type }} + + # Skip testing, for now. It seems perf_event_open does not succeed on the + # hosting machine, very likely a permissions issue. + # TODO(mtrofin): Enable test. + # - name: test + # shell: bash + # working-directory: ${{ runner.workspace }}/_build + # run: ctest -C ${{ matrix.build_type }} --rerun-failed --output-on-failure + diff --git a/third_party/benchmark/.github/workflows/build-and-test.yml b/third_party/benchmark/.github/workflows/build-and-test.yml new file mode 100644 index 0000000..d05300d --- /dev/null +++ b/third_party/benchmark/.github/workflows/build-and-test.yml @@ -0,0 +1,161 @@ +name: build-and-test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + # TODO: add 32-bit builds (g++ and clang++) for ubuntu + # (requires g++-multilib and libc6:i386) + # TODO: add coverage build (requires lcov) + # TODO: add clang + libc++ builds for ubuntu + job: + name: ${{ matrix.os }}.${{ matrix.build_type }}.${{ matrix.lib }}.${{ matrix.compiler }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-20.04, macos-latest] + build_type: ['Release', 'Debug'] + compiler: ['g++', 'clang++'] + lib: 
['shared', 'static'] + + steps: + - uses: actions/checkout@v4 + + - uses: lukka/get-cmake@latest + + - name: create build environment + run: cmake -E make_directory ${{ runner.workspace }}/_build + + - name: setup cmake initial cache + run: touch compiler-cache.cmake + + - name: configure cmake + env: + CXX: ${{ matrix.compiler }} + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: > + cmake -C ${{ github.workspace }}/compiler-cache.cmake + $GITHUB_WORKSPACE + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DBUILD_SHARED_LIBS=${{ matrix.lib == 'shared' }} + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + -DCMAKE_CXX_COMPILER=${{ env.CXX }} + -DCMAKE_CXX_VISIBILITY_PRESET=hidden + -DCMAKE_VISIBILITY_INLINES_HIDDEN=ON + + - name: build + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: cmake --build . --config ${{ matrix.build_type }} + + - name: test + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: ctest -C ${{ matrix.build_type }} -VV + + msvc: + name: ${{ matrix.os }}.${{ matrix.build_type }}.${{ matrix.lib }}.${{ matrix.msvc }} + runs-on: ${{ matrix.os }} + defaults: + run: + shell: powershell + strategy: + fail-fast: false + matrix: + msvc: + - VS-16-2019 + - VS-17-2022 + arch: + - x64 + build_type: + - Debug + - Release + lib: + - shared + - static + include: + - msvc: VS-16-2019 + os: windows-2019 + generator: 'Visual Studio 16 2019' + - msvc: VS-17-2022 + os: windows-2022 + generator: 'Visual Studio 17 2022' + + steps: + - uses: actions/checkout@v4 + + - uses: lukka/get-cmake@latest + + - name: configure cmake + run: > + cmake -S . 
-B _build/ + -A ${{ matrix.arch }} + -G "${{ matrix.generator }}" + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DBUILD_SHARED_LIBS=${{ matrix.lib == 'shared' }} + + - name: build + run: cmake --build _build/ --config ${{ matrix.build_type }} + + - name: test + run: ctest --test-dir _build/ -C ${{ matrix.build_type }} -VV + + msys2: + name: ${{ matrix.os }}.${{ matrix.build_type }}.${{ matrix.lib }}.${{ matrix.msys2.msystem }} + runs-on: ${{ matrix.os }} + defaults: + run: + shell: msys2 {0} + strategy: + fail-fast: false + matrix: + os: [ windows-latest ] + msys2: + - { msystem: MINGW64, arch: x86_64, family: GNU, compiler: g++ } + - { msystem: MINGW32, arch: i686, family: GNU, compiler: g++ } + - { msystem: CLANG64, arch: x86_64, family: LLVM, compiler: clang++ } + - { msystem: CLANG32, arch: i686, family: LLVM, compiler: clang++ } + - { msystem: UCRT64, arch: x86_64, family: GNU, compiler: g++ } + build_type: + - Debug + - Release + lib: + - shared + - static + + steps: + - uses: actions/checkout@v4 + + - name: Install Base Dependencies + uses: msys2/setup-msys2@v2 + with: + cache: false + msystem: ${{ matrix.msys2.msystem }} + update: true + install: >- + git + base-devel + pacboy: >- + cc:p + cmake:p + ninja:p + + - name: configure cmake + env: + CXX: ${{ matrix.msys2.compiler }} + run: > + cmake -S . 
-B _build/ + -GNinja + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DBUILD_SHARED_LIBS=${{ matrix.lib == 'shared' }} + + - name: build + run: cmake --build _build/ --config ${{ matrix.build_type }} + + - name: test + run: ctest --test-dir _build/ -C ${{ matrix.build_type }} -VV diff --git a/third_party/benchmark/.github/workflows/clang-format-lint.yml b/third_party/benchmark/.github/workflows/clang-format-lint.yml new file mode 100644 index 0000000..8f089dc --- /dev/null +++ b/third_party/benchmark/.github/workflows/clang-format-lint.yml @@ -0,0 +1,18 @@ +name: clang-format-lint +on: + push: {} + pull_request: {} + +jobs: + job: + name: check-clang-format + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: DoozyX/clang-format-lint-action@v0.15 + with: + source: './include/benchmark ./src ./test' + extensions: 'h,cc' + clangFormatVersion: 12 + style: Google diff --git a/third_party/benchmark/.github/workflows/clang-tidy.yml b/third_party/benchmark/.github/workflows/clang-tidy.yml new file mode 100644 index 0000000..37a61cd --- /dev/null +++ b/third_party/benchmark/.github/workflows/clang-tidy.yml @@ -0,0 +1,38 @@ +name: clang-tidy + +on: + push: {} + pull_request: {} + +jobs: + job: + name: run-clang-tidy + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v4 + + - name: install clang-tidy + run: sudo apt update && sudo apt -y install clang-tidy + + - name: create build environment + run: cmake -E make_directory ${{ runner.workspace }}/_build + + - name: configure cmake + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: > + cmake $GITHUB_WORKSPACE + -DBENCHMARK_ENABLE_ASSEMBLY_TESTS=OFF + -DBENCHMARK_ENABLE_LIBPFM=OFF + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DCMAKE_C_COMPILER=clang + -DCMAKE_CXX_COMPILER=clang++ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + -DGTEST_COMPILE_COMMANDS=OFF + + - name: run + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: run-clang-tidy 
-checks=*,-clang-analyzer-deadcode* diff --git a/third_party/benchmark/.github/workflows/doxygen.yml b/third_party/benchmark/.github/workflows/doxygen.yml new file mode 100644 index 0000000..40c1cb4 --- /dev/null +++ b/third_party/benchmark/.github/workflows/doxygen.yml @@ -0,0 +1,28 @@ +name: doxygen + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build-and-deploy: + name: Build HTML documentation + runs-on: ubuntu-latest + steps: + - name: Fetching sources + uses: actions/checkout@v4 + + - name: Installing build dependencies + run: | + sudo apt update + sudo apt install doxygen gcc git + + - name: Creating build directory + run: mkdir build + + - name: Building HTML documentation with Doxygen + run: | + cmake -S . -B build -DBENCHMARK_ENABLE_TESTING:BOOL=OFF -DBENCHMARK_ENABLE_DOXYGEN:BOOL=ON -DBENCHMARK_INSTALL_DOCS:BOOL=ON + cmake --build build --target benchmark_doxygen diff --git a/third_party/benchmark/.github/workflows/pre-commit.yml b/third_party/benchmark/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..8b217e9 --- /dev/null +++ b/third_party/benchmark/.github/workflows/pre-commit.yml @@ -0,0 +1,38 @@ +name: python + Bazel pre-commit checks + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + pre-commit: + runs-on: ubuntu-latest + env: + MYPY_CACHE_DIR: "${{ github.workspace }}/.cache/mypy" + RUFF_CACHE_DIR: "${{ github.workspace }}/.cache/ruff" + PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pre-commit" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.11 + cache: pip + cache-dependency-path: pyproject.toml + - name: Install dependencies + run: python -m pip install ".[dev]" + - name: Cache pre-commit tools + uses: actions/cache@v4 + with: + path: | + ${{ env.MYPY_CACHE_DIR }} + ${{ env.RUFF_CACHE_DIR }} + ${{ env.PRE_COMMIT_HOME }} + key: ${{ runner.os }}-${{ 
hashFiles('.pre-commit-config.yaml') }}-linter-cache + - name: Run pre-commit checks + run: pre-commit run --all-files --verbose --show-diff-on-failure diff --git a/third_party/benchmark/.github/workflows/sanitizer.yml b/third_party/benchmark/.github/workflows/sanitizer.yml new file mode 100644 index 0000000..4992153 --- /dev/null +++ b/third_party/benchmark/.github/workflows/sanitizer.yml @@ -0,0 +1,96 @@ +name: sanitizer + +on: + push: {} + pull_request: {} + +env: + UBSAN_OPTIONS: "print_stacktrace=1" + +jobs: + job: + name: ${{ matrix.sanitizer }}.${{ matrix.build_type }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + build_type: ['Debug', 'RelWithDebInfo'] + sanitizer: ['asan', 'ubsan', 'tsan', 'msan'] + + steps: + - uses: actions/checkout@v4 + + - name: configure msan env + if: matrix.sanitizer == 'msan' + run: | + echo "EXTRA_FLAGS=-g -O2 -fno-omit-frame-pointer -fsanitize=memory -fsanitize-memory-track-origins" >> $GITHUB_ENV + echo "LIBCXX_SANITIZER=MemoryWithOrigins" >> $GITHUB_ENV + + - name: configure ubsan env + if: matrix.sanitizer == 'ubsan' + run: | + echo "EXTRA_FLAGS=-g -O2 -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=all" >> $GITHUB_ENV + echo "LIBCXX_SANITIZER=Undefined" >> $GITHUB_ENV + + - name: configure asan env + if: matrix.sanitizer == 'asan' + run: | + echo "EXTRA_FLAGS=-g -O2 -fno-omit-frame-pointer -fsanitize=address -fno-sanitize-recover=all" >> $GITHUB_ENV + echo "LIBCXX_SANITIZER=Address" >> $GITHUB_ENV + + - name: configure tsan env + if: matrix.sanitizer == 'tsan' + run: | + echo "EXTRA_FLAGS=-g -O2 -fno-omit-frame-pointer -fsanitize=thread -fno-sanitize-recover=all" >> $GITHUB_ENV + echo "LIBCXX_SANITIZER=Thread" >> $GITHUB_ENV + + - name: fine-tune asan options + # in asan we get an error from std::regex. ignore it. 
+ if: matrix.sanitizer == 'asan' + run: | + echo "ASAN_OPTIONS=alloc_dealloc_mismatch=0" >> $GITHUB_ENV + + - name: setup clang + uses: egor-tensin/setup-clang@v1 + with: + version: latest + platform: x64 + + - name: configure clang + run: | + echo "CC=cc" >> $GITHUB_ENV + echo "CXX=c++" >> $GITHUB_ENV + + - name: build libc++ (non-asan) + if: matrix.sanitizer != 'asan' + run: | + "${GITHUB_WORKSPACE}/.github/libcxx-setup.sh" + echo "EXTRA_CXX_FLAGS=-stdlib=libc++ -L ${GITHUB_WORKSPACE}/llvm-build/lib -lc++abi -Isystem${GITHUB_WORKSPACE}/llvm-build/include -Isystem${GITHUB_WORKSPACE}/llvm-build/include/c++/v1 -Wl,-rpath,${GITHUB_WORKSPACE}/llvm-build/lib" >> $GITHUB_ENV + + - name: create build environment + run: cmake -E make_directory ${{ runner.workspace }}/_build + + - name: configure cmake + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: > + VERBOSE=1 + cmake $GITHUB_WORKSPACE + -DBENCHMARK_ENABLE_ASSEMBLY_TESTS=OFF + -DBENCHMARK_ENABLE_LIBPFM=OFF + -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON + -DCMAKE_C_COMPILER=${{ env.CC }} + -DCMAKE_CXX_COMPILER=${{ env.CXX }} + -DCMAKE_C_FLAGS="${{ env.EXTRA_FLAGS }}" + -DCMAKE_CXX_FLAGS="${{ env.EXTRA_FLAGS }} ${{ env.EXTRA_CXX_FLAGS }}" + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + + - name: build + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: cmake --build . 
--config ${{ matrix.build_type }} + + - name: test + shell: bash + working-directory: ${{ runner.workspace }}/_build + run: ctest -C ${{ matrix.build_type }} -VV diff --git a/third_party/benchmark/.github/workflows/test_bindings.yml b/third_party/benchmark/.github/workflows/test_bindings.yml new file mode 100644 index 0000000..b6ac9be --- /dev/null +++ b/third_party/benchmark/.github/workflows/test_bindings.yml @@ -0,0 +1,30 @@ +name: test-bindings + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + python_bindings: + name: Test GBM Python ${{ matrix.python-version }} bindings on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ ubuntu-latest, macos-latest, windows-latest ] + python-version: [ "3.10", "3.11", "3.12", "3.13" ] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install GBM Python bindings on ${{ matrix.os }} + run: python -m pip install . 
+ - name: Run example on ${{ matrix.os }} under Python ${{ matrix.python-version }} + run: python bindings/python/google_benchmark/example.py diff --git a/third_party/benchmark/.github/workflows/wheels.yml b/third_party/benchmark/.github/workflows/wheels.yml new file mode 100644 index 0000000..b463ff8 --- /dev/null +++ b/third_party/benchmark/.github/workflows/wheels.yml @@ -0,0 +1,100 @@ +name: Build and upload Python wheels + +on: + workflow_dispatch: + release: + types: + - published + +jobs: + build_sdist: + name: Build source distribution + runs-on: ubuntu-latest + steps: + - name: Check out repo + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: python -m pip install build + - name: Build sdist + run: python -m build --sdist + - uses: actions/upload-artifact@v4 + with: + name: dist-sdist + path: dist/*.tar.gz + + build_wheels: + name: Build Google Benchmark wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-13, macos-14, windows-latest] + + steps: + - name: Check out Google Benchmark + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + name: Install Python 3.12 + with: + python-version: "3.12" + - run: pip install --upgrade pip uv + + - name: Set up QEMU + if: runner.os == 'Linux' + uses: docker/setup-qemu-action@v3 + with: + platforms: all + + - name: Build wheels on ${{ matrix.os }} using cibuildwheel + uses: pypa/cibuildwheel@v2.21.3 + env: + CIBW_BUILD: "cp310-* cp311-* cp312-*" + CIBW_BUILD_FRONTEND: "build[uv]" + CIBW_SKIP: "*-musllinux_*" + CIBW_TEST_SKIP: "cp38-macosx_*:arm64" + CIBW_ARCHS_LINUX: auto64 aarch64 + CIBW_ARCHS_WINDOWS: auto64 + CIBW_BEFORE_ALL_LINUX: bash .github/install_bazel.sh + # Grab the rootless Bazel installation inside the container. 
+ CIBW_ENVIRONMENT_LINUX: PATH=$PATH:$HOME/bin + CIBW_TEST_COMMAND: python {project}/bindings/python/google_benchmark/example.py + # unused by Bazel, but needed explicitly by delocate on MacOS. + MACOSX_DEPLOYMENT_TARGET: "10.14" + + - name: Upload Google Benchmark ${{ matrix.os }} wheels + uses: actions/upload-artifact@v4 + with: + name: dist-${{ matrix.os }} + path: wheelhouse/*.whl + + merge_wheels: + name: Merge all built wheels into one artifact + runs-on: ubuntu-latest + needs: build_wheels + steps: + - name: Merge wheels + uses: actions/upload-artifact/merge@v4 + with: + name: dist + pattern: dist-* + delete-merged: true + + pypi_upload: + name: Publish google-benchmark wheels to PyPI + needs: [merge_wheels] + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v4 + with: + path: dist + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/third_party/benchmark/.gitignore b/third_party/benchmark/.gitignore new file mode 100644 index 0000000..24a1fb6 --- /dev/null +++ b/third_party/benchmark/.gitignore @@ -0,0 +1,68 @@ +*.a +*.so +*.so.?* +*.dll +*.exe +*.dylib +*.cmake +!/cmake/*.cmake +!/test/AssemblyTests.cmake +*~ +*.swp +*.pyc +__pycache__ +.DS_Store + +# lcov +*.lcov +/lcov + +# cmake files. +/Testing +CMakeCache.txt +CMakeFiles/ +cmake_install.cmake + +# makefiles. +Makefile + +# in-source build. +bin/ +lib/ +/test/*_test + +# exuberant ctags. +tags + +# YouCompleteMe configuration. +.ycm_extra_conf.pyc + +# ninja generated files. +.ninja_deps +.ninja_log +build.ninja +install_manifest.txt +rules.ninja + +# bazel output symlinks. +bazel-* +MODULE.bazel.lock + +# out-of-source build top-level folders. 
+build/ +_build/ +build*/ + +# in-source dependencies +/googletest/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +CMakeSettings.json + +# Visual Studio Code cache/options directory +.vscode/ + +# Python build stuff +dist/ +*.egg-info* diff --git a/third_party/benchmark/.pre-commit-config.yaml b/third_party/benchmark/.pre-commit-config.yaml new file mode 100644 index 0000000..2a51592 --- /dev/null +++ b/third_party/benchmark/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: https://github.com/keith/pre-commit-buildifier + rev: 7.3.1 + hooks: + - id: buildifier + - id: buildifier-lint + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 + hooks: + - id: mypy + types_or: [ python, pyi ] + args: [ "--ignore-missing-imports", "--scripts-are-modules" ] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.2 + hooks: + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] + - id: ruff-format diff --git a/third_party/benchmark/.ycm_extra_conf.py b/third_party/benchmark/.ycm_extra_conf.py new file mode 100644 index 0000000..caf257f --- /dev/null +++ b/third_party/benchmark/.ycm_extra_conf.py @@ -0,0 +1,120 @@ +import os + +import ycm_core + +# These are the compilation flags that will be used in case there's no +# compilation database set (by default, one is not set). +# CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR. +flags = [ + "-Wall", + "-Werror", + "-pedantic-errors", + "-std=c++0x", + "-fno-strict-aliasing", + "-O3", + "-DNDEBUG", + # ...and the same thing goes for the magic -x option which specifies the + # language that the files to be compiled are written in. This is mostly + # relevant for c++ headers. + # For a C project, you would set this to 'c' instead of 'c++'. + "-x", + "c++", + "-I", + "include", + "-isystem", + "/usr/include", + "-isystem", + "/usr/local/include", +] + + +# Set this to the absolute path to the folder (NOT the file!) 
containing the +# compile_commands.json file to use that instead of 'flags'. See here for +# more details: http://clang.llvm.org/docs/JSONCompilationDatabase.html +# +# Most projects will NOT need to set this to anything; you can just change the +# 'flags' list of compilation flags. Notice that YCM itself uses that approach. +compilation_database_folder = "" + +if os.path.exists(compilation_database_folder): + database = ycm_core.CompilationDatabase(compilation_database_folder) +else: + database = None + +SOURCE_EXTENSIONS = [".cc"] + + +def DirectoryOfThisScript(): + return os.path.dirname(os.path.abspath(__file__)) + + +def MakeRelativePathsInFlagsAbsolute(flags, working_directory): + if not working_directory: + return list(flags) + new_flags = [] + make_next_absolute = False + path_flags = ["-isystem", "-I", "-iquote", "--sysroot="] + for flag in flags: + new_flag = flag + + if make_next_absolute: + make_next_absolute = False + if not flag.startswith("/"): + new_flag = os.path.join(working_directory, flag) + + for path_flag in path_flags: + if flag == path_flag: + make_next_absolute = True + break + + if flag.startswith(path_flag): + path = flag[len(path_flag) :] + new_flag = path_flag + os.path.join(working_directory, path) + break + + if new_flag: + new_flags.append(new_flag) + return new_flags + + +def IsHeaderFile(filename): + extension = os.path.splitext(filename)[1] + return extension in [".h", ".hxx", ".hpp", ".hh"] + + +def GetCompilationInfoForFile(filename): + # The compilation_commands.json file generated by CMake does not have entries + # for header files. So we do our best by asking the db for flags for a + # corresponding source file, if any. If one exists, the flags for that file + # should be good enough. 
+ if IsHeaderFile(filename): + basename = os.path.splitext(filename)[0] + for extension in SOURCE_EXTENSIONS: + replacement_file = basename + extension + if os.path.exists(replacement_file): + compilation_info = database.GetCompilationInfoForFile( + replacement_file + ) + if compilation_info.compiler_flags_: + return compilation_info + return None + return database.GetCompilationInfoForFile(filename) + + +def FlagsForFile(filename, **kwargs): + if database: + # Bear in mind that compilation_info.compiler_flags_ does NOT return a + # python list, but a "list-like" StringVec object + compilation_info = GetCompilationInfoForFile(filename) + if not compilation_info: + return None + + final_flags = MakeRelativePathsInFlagsAbsolute( + compilation_info.compiler_flags_, + compilation_info.compiler_working_dir_, + ) + else: + relative_to = DirectoryOfThisScript() + final_flags = MakeRelativePathsInFlagsAbsolute(flags, relative_to) + + return {"flags": final_flags, "do_cache": True} diff --git a/third_party/benchmark/AUTHORS b/third_party/benchmark/AUTHORS new file mode 100644 index 0000000..2170e46 --- /dev/null +++ b/third_party/benchmark/AUTHORS @@ -0,0 +1,72 @@ +# This is the official list of benchmark authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. +# +# Names should be added to this file as: +# Name or Organization +# The email address is not required for organizations. +# +# Please keep the list sorted. + +Albert Pretorius +Alex Steele +Andriy Berestovskyy +Arne Beer +Carto +Cezary Skrzyński +Christian Wassermann +Christopher Seymour +Colin Braley +Daniel Harvey +David Coeurjolly +Deniz Evrenci +Dirac Research +Dominik Czarnota +Dominik Korman +Donald Aingworth +Eric Backus +Eric Fiselier +Eugene Zhuk +Evgeny Safronov +Fabien Pichot +Federico Ficarelli +Felix Homann +Gergely Meszaros +Gergő Szitár +Google Inc. 
+Henrique Bucher +International Business Machines Corporation +Ismael Jimenez Martinez +Jern-Kuan Leong +JianXiong Zhou +Joao Paulo Magalhaes +Jordan Williams +Jussi Knuuttila +Kaito Udagawa +Kishan Kumar +Lei Xu +Marcel Jacobse +Matt Clarkson +Maxim Vafin +Mike Apodaca +Min-Yih Hsu +MongoDB Inc. +Nick Hutchinson +Norman Heino +Oleksandr Sochka +Ori Livneh +Paul Redmond +Radoslav Yovchev +Raghu Raja +Rainer Orth +Roman Lebedev +Sayan Bhattacharjee +Shapr3D +Shuo Chen +Staffan Tjernstrom +Steinar H. Gunderson +Stripe, Inc. +Tobias Schmidt +Yixuan Qiu +Yusuke Suzuki +Zbigniew Skowron diff --git a/third_party/benchmark/BUILD.bazel b/third_party/benchmark/BUILD.bazel new file mode 100644 index 0000000..3451b4e --- /dev/null +++ b/third_party/benchmark/BUILD.bazel @@ -0,0 +1,114 @@ +licenses(["notice"]) + +COPTS = [ + "-pedantic", + "-pedantic-errors", + "-std=c++17", + "-Wall", + "-Wconversion", + "-Wextra", + "-Wshadow", + # "-Wshorten-64-to-32", + "-Wfloat-equal", + "-fstrict-aliasing", + ## assert() are used a lot in tests upstream, which may be optimised out leading to + ## unused-variable warning. 
+ "-Wno-unused-variable", + "-Werror=old-style-cast", +] + +config_setting( + name = "qnx", + constraint_values = ["@platforms//os:qnx"], + values = { + "cpu": "x64_qnx", + }, + visibility = [":__subpackages__"], +) + +config_setting( + name = "windows", + constraint_values = ["@platforms//os:windows"], + values = { + "cpu": "x64_windows", + }, + visibility = [":__subpackages__"], +) + +config_setting( + name = "macos", + constraint_values = ["@platforms//os:macos"], + visibility = ["//visibility:public"], +) + +config_setting( + name = "perfcounters", + define_values = { + "pfm": "1", + }, + visibility = [":__subpackages__"], +) + +cc_library( + name = "benchmark", + srcs = glob( + [ + "src/*.cc", + "src/*.h", + ], + exclude = ["src/benchmark_main.cc"], + ), + hdrs = [ + "include/benchmark/benchmark.h", + "include/benchmark/export.h", + ], + copts = select({ + ":windows": [], + "//conditions:default": COPTS, + }), + defines = [ + "BENCHMARK_STATIC_DEFINE", + "BENCHMARK_VERSION=\\\"" + (module_version() if module_version() != None else "") + "\\\"", + ] + select({ + ":perfcounters": ["HAVE_LIBPFM"], + "//conditions:default": [], + }), + includes = ["include"], + linkopts = select({ + ":windows": ["-DEFAULTLIB:shlwapi.lib"], + "//conditions:default": ["-pthread"], + }), + # Only static linking is allowed; no .so will be produced. + # Using `defines` (i.e. not `local_defines`) means that no + # dependent rules need to bother about defining the macro. 
+ linkstatic = True, + local_defines = [ + # Turn on Large-file Support + "_FILE_OFFSET_BITS=64", + "_LARGEFILE64_SOURCE", + "_LARGEFILE_SOURCE", + ], + visibility = ["//visibility:public"], + deps = select({ + ":perfcounters": ["@libpfm"], + "//conditions:default": [], + }), +) + +cc_library( + name = "benchmark_main", + srcs = ["src/benchmark_main.cc"], + hdrs = [ + "include/benchmark/benchmark.h", + "include/benchmark/export.h", + ], + includes = ["include"], + visibility = ["//visibility:public"], + deps = [":benchmark"], +) + +cc_library( + name = "benchmark_internal_headers", + hdrs = glob(["src/*.h"]), + visibility = ["//test:__pkg__"], +) diff --git a/third_party/benchmark/CMakeLists.txt b/third_party/benchmark/CMakeLists.txt new file mode 100644 index 0000000..f045fcd --- /dev/null +++ b/third_party/benchmark/CMakeLists.txt @@ -0,0 +1,355 @@ +# Require CMake 3.10. If available, use the policies up to CMake 3.22. +cmake_minimum_required (VERSION 3.13...3.22) + +project (benchmark VERSION 1.9.1 LANGUAGES CXX) + +option(BENCHMARK_ENABLE_TESTING "Enable testing of the benchmark library." ON) +option(BENCHMARK_ENABLE_EXCEPTIONS "Enable the use of exceptions in the benchmark library." ON) +option(BENCHMARK_ENABLE_LTO "Enable link time optimisation of the benchmark library." OFF) +option(BENCHMARK_USE_LIBCXX "Build and test using libc++ as the standard library." OFF) +option(BENCHMARK_ENABLE_WERROR "Build Release candidates with -Werror." ON) +option(BENCHMARK_FORCE_WERROR "Build Release candidates with -Werror regardless of compiler issues." OFF) + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "PGI") + # PGC++ maybe reporting false positives. 
+ set(BENCHMARK_ENABLE_WERROR OFF) +endif() +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "NVHPC") + set(BENCHMARK_ENABLE_WERROR OFF) +endif() +if(BENCHMARK_FORCE_WERROR) + set(BENCHMARK_ENABLE_WERROR ON) +endif(BENCHMARK_FORCE_WERROR) + +if(NOT (MSVC OR CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")) + option(BENCHMARK_BUILD_32_BITS "Build a 32 bit version of the library." OFF) +else() + set(BENCHMARK_BUILD_32_BITS OFF CACHE BOOL "Build a 32 bit version of the library - unsupported when using MSVC)" FORCE) +endif() +option(BENCHMARK_ENABLE_INSTALL "Enable installation of benchmark. (Projects embedding benchmark may want to turn this OFF.)" ON) +option(BENCHMARK_ENABLE_DOXYGEN "Build documentation with Doxygen." OFF) +option(BENCHMARK_INSTALL_DOCS "Enable installation of documentation." ON) + +# Allow unmet dependencies to be met using CMake's ExternalProject mechanics, which +# may require downloading the source code. +option(BENCHMARK_DOWNLOAD_DEPENDENCIES "Allow the downloading and in-tree building of unmet dependencies" OFF) + +# This option can be used to disable building and running unit tests which depend on gtest +# in cases where it is not possible to build or find a valid version of gtest. +option(BENCHMARK_ENABLE_GTEST_TESTS "Enable building the unit tests which depend on gtest" ON) +option(BENCHMARK_USE_BUNDLED_GTEST "Use bundled GoogleTest. If disabled, the find_package(GTest) will be used." ON) + +option(BENCHMARK_ENABLE_LIBPFM "Enable performance counters provided by libpfm" OFF) + +# Export only public symbols +set(CMAKE_CXX_VISIBILITY_PRESET hidden) +set(CMAKE_VISIBILITY_INLINES_HIDDEN ON) + +if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # As of CMake 3.18, CMAKE_SYSTEM_PROCESSOR is not set properly for MSVC and + # cross-compilation (e.g. Host=x86_64, target=aarch64) requires using the + # undocumented, but working variable. 
+ # See https://gitlab.kitware.com/cmake/cmake/-/issues/15170 + set(CMAKE_SYSTEM_PROCESSOR ${MSVC_CXX_ARCHITECTURE_ID}) + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ARM") + set(CMAKE_CROSSCOMPILING TRUE) + endif() +endif() + +set(ENABLE_ASSEMBLY_TESTS_DEFAULT OFF) +function(should_enable_assembly_tests) + if(CMAKE_BUILD_TYPE) + string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER) + if (${CMAKE_BUILD_TYPE_LOWER} MATCHES "coverage") + # FIXME: The --coverage flag needs to be removed when building assembly + # tests for this to work. + return() + endif() + endif() + if (MSVC OR CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC") + return() + elseif(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + return() + elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + # FIXME: Make these work on 32 bit builds + return() + elseif(BENCHMARK_BUILD_32_BITS) + # FIXME: Make these work on 32 bit builds + return() + endif() + find_program(LLVM_FILECHECK_EXE FileCheck) + if (LLVM_FILECHECK_EXE) + set(LLVM_FILECHECK_EXE "${LLVM_FILECHECK_EXE}" CACHE PATH "llvm filecheck" FORCE) + message(STATUS "LLVM FileCheck Found: ${LLVM_FILECHECK_EXE}") + else() + message(STATUS "Failed to find LLVM FileCheck") + return() + endif() + set(ENABLE_ASSEMBLY_TESTS_DEFAULT ON PARENT_SCOPE) +endfunction() +should_enable_assembly_tests() + +# This option disables the building and running of the assembly verification tests +option(BENCHMARK_ENABLE_ASSEMBLY_TESTS "Enable building and running the assembly tests" + ${ENABLE_ASSEMBLY_TESTS_DEFAULT}) + +# Make sure we can import out CMake functions +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + + +# Read the git tags to determine the project version +include(GetGitVersion) +get_git_version(GIT_VERSION) + +# If no git version can be determined, use the version +# from the project() command +if ("${GIT_VERSION}" STREQUAL "v0.0.0") + set(VERSION "v${benchmark_VERSION}") +else() + set(VERSION 
"${GIT_VERSION}") +endif() + +# Normalize version: drop "v" prefix, replace first "-" with ".", +# drop everything after second "-" (including said "-"). +string(STRIP ${VERSION} VERSION) +if(VERSION MATCHES v[^-]*-) + string(REGEX REPLACE "v([^-]*)-([0-9]+)-.*" "\\1.\\2" NORMALIZED_VERSION ${VERSION}) +else() + string(REGEX REPLACE "v(.*)" "\\1" NORMALIZED_VERSION ${VERSION}) +endif() + +# Tell the user what versions we are using +message(STATUS "Google Benchmark version: ${VERSION}, normalized to ${NORMALIZED_VERSION}") + +# The version of the libraries +set(GENERIC_LIB_VERSION ${NORMALIZED_VERSION}) +string(SUBSTRING ${NORMALIZED_VERSION} 0 1 GENERIC_LIB_SOVERSION) + +# Import our CMake modules +include(AddCXXCompilerFlag) +include(CheckCXXCompilerFlag) +include(CheckLibraryExists) +include(CXXFeatureCheck) + +check_library_exists(rt shm_open "" HAVE_LIB_RT) + +if (BENCHMARK_BUILD_32_BITS) + add_required_cxx_compiler_flag(-m32) +endif() + +set(BENCHMARK_CXX_STANDARD 17) + +set(CMAKE_CXX_STANDARD ${BENCHMARK_CXX_STANDARD}) +set(CMAKE_CXX_STANDARD_REQUIRED YES) +set(CMAKE_CXX_EXTENSIONS OFF) + +if (MSVC) + # Turn compiler warnings up to 11 + string(REGEX REPLACE "[-/]W[1-4]" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4 /MP") + add_definitions(-D_CRT_SECURE_NO_WARNINGS) + + if(BENCHMARK_ENABLE_WERROR) + add_cxx_compiler_flag(-WX) + endif() + + if (NOT BENCHMARK_ENABLE_EXCEPTIONS) + add_cxx_compiler_flag(-EHs-) + add_cxx_compiler_flag(-EHa-) + add_definitions(-D_HAS_EXCEPTIONS=0) + endif() + # Link time optimisation + if (BENCHMARK_ENABLE_LTO) + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /GL") + set(CMAKE_STATIC_LINKER_FLAGS_RELEASE "${CMAKE_STATIC_LINKER_FLAGS_RELEASE} /LTCG") + set(CMAKE_SHARED_LINKER_FLAGS_RELEASE "${CMAKE_SHARED_LINKER_FLAGS_RELEASE} /LTCG") + set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS_RELEASE} /LTCG") + + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO 
"${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /GL") + string(REGEX REPLACE "[-/]INCREMENTAL" "/INCREMENTAL:NO" CMAKE_STATIC_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_STATIC_LINKER_FLAGS_RELWITHDEBINFO}") + set(CMAKE_STATIC_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_STATIC_LINKER_FLAGS_RELWITHDEBINFO} /LTCG") + string(REGEX REPLACE "[-/]INCREMENTAL" "/INCREMENTAL:NO" CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO}") + set(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO} /LTCG") + string(REGEX REPLACE "[-/]INCREMENTAL" "/INCREMENTAL:NO" CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO}") + set(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO} /LTCG") + + set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /GL") + set(CMAKE_STATIC_LINKER_FLAGS_MINSIZEREL "${CMAKE_STATIC_LINKER_FLAGS_MINSIZEREL} /LTCG") + set(CMAKE_SHARED_LINKER_FLAGS_MINSIZEREL "${CMAKE_SHARED_LINKER_FLAGS_MINSIZEREL} /LTCG") + set(CMAKE_EXE_LINKER_FLAGS_MINSIZEREL "${CMAKE_EXE_LINKER_FLAGS_MINSIZEREL} /LTCG") + endif() +else() + # Turn on Large-file Support + add_definitions(-D_FILE_OFFSET_BITS=64) + add_definitions(-D_LARGEFILE64_SOURCE) + add_definitions(-D_LARGEFILE_SOURCE) + # Turn compiler warnings up to 11 + add_cxx_compiler_flag(-Wall) + add_cxx_compiler_flag(-Wextra) + add_cxx_compiler_flag(-Wshadow) + add_cxx_compiler_flag(-Wfloat-equal) + add_cxx_compiler_flag(-Wold-style-cast) + add_cxx_compiler_flag(-Wconversion) + if(BENCHMARK_ENABLE_WERROR) + add_cxx_compiler_flag(-Werror) + endif() + if (NOT BENCHMARK_ENABLE_TESTING) + # Disable warning when compiling tests as gtest does not use 'override'. 
+ add_cxx_compiler_flag(-Wsuggest-override) + endif() + add_cxx_compiler_flag(-pedantic) + add_cxx_compiler_flag(-pedantic-errors) + add_cxx_compiler_flag(-Wshorten-64-to-32) + add_cxx_compiler_flag(-fstrict-aliasing) + # Disable warnings regarding deprecated parts of the library while building + # and testing those parts of the library. + add_cxx_compiler_flag(-Wno-deprecated-declarations) + if (CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") + # Intel silently ignores '-Wno-deprecated-declarations', + # warning no. 1786 must be explicitly disabled. + # See #631 for rationale. + add_cxx_compiler_flag(-wd1786) + add_cxx_compiler_flag(-fno-finite-math-only) + endif() + # Disable deprecation warnings for release builds (when -Werror is enabled). + if(BENCHMARK_ENABLE_WERROR) + add_cxx_compiler_flag(-Wno-deprecated) + endif() + if (NOT BENCHMARK_ENABLE_EXCEPTIONS) + add_cxx_compiler_flag(-fno-exceptions) + endif() + + if (HAVE_CXX_FLAG_FSTRICT_ALIASING) + if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "Intel" AND NOT CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") #ICC17u2: Many false positives for Wstrict-aliasing + add_cxx_compiler_flag(-Wstrict-aliasing) + endif() + endif() + # ICC17u2: overloaded virtual function "benchmark::Fixture::SetUp" is only partially overridden + # (because of deprecated overload) + add_cxx_compiler_flag(-wd654) + add_cxx_compiler_flag(-Wthread-safety) + if (HAVE_CXX_FLAG_WTHREAD_SAFETY) + cxx_feature_check(THREAD_SAFETY_ATTRIBUTES "-DINCLUDE_DIRECTORIES=${PROJECT_SOURCE_DIR}/include") + endif() + + # On most UNIX like platforms g++ and clang++ define _GNU_SOURCE as a + # predefined macro, which turns on all of the wonderful libc extensions. + # However g++ doesn't do this in Cygwin so we have to define it ourselves + # since we depend on GNU/POSIX/BSD extensions. 
+ if (CYGWIN) + add_definitions(-D_GNU_SOURCE=1) + endif() + + if (QNXNTO) + add_definitions(-D_QNX_SOURCE) + endif() + + # Link time optimisation + if (BENCHMARK_ENABLE_LTO) + add_cxx_compiler_flag(-flto) + add_cxx_compiler_flag(-Wno-lto-type-mismatch) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + find_program(GCC_AR gcc-ar) + if (GCC_AR) + set(CMAKE_AR ${GCC_AR}) + endif() + find_program(GCC_RANLIB gcc-ranlib) + if (GCC_RANLIB) + set(CMAKE_RANLIB ${GCC_RANLIB}) + endif() + elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + include(llvm-toolchain) + endif() + endif() + + # Coverage build type + set(BENCHMARK_CXX_FLAGS_COVERAGE "${CMAKE_CXX_FLAGS_DEBUG}" + CACHE STRING "Flags used by the C++ compiler during coverage builds." + FORCE) + set(BENCHMARK_EXE_LINKER_FLAGS_COVERAGE "${CMAKE_EXE_LINKER_FLAGS_DEBUG}" + CACHE STRING "Flags used for linking binaries during coverage builds." + FORCE) + set(BENCHMARK_SHARED_LINKER_FLAGS_COVERAGE "${CMAKE_SHARED_LINKER_FLAGS_DEBUG}" + CACHE STRING "Flags used by the shared libraries linker during coverage builds." 
+ FORCE) + mark_as_advanced( + BENCHMARK_CXX_FLAGS_COVERAGE + BENCHMARK_EXE_LINKER_FLAGS_COVERAGE + BENCHMARK_SHARED_LINKER_FLAGS_COVERAGE) + set(CMAKE_BUILD_TYPE "${CMAKE_BUILD_TYPE}" CACHE STRING + "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel Coverage.") + add_cxx_compiler_flag(--coverage COVERAGE) +endif() + +if (BENCHMARK_USE_LIBCXX) + if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + add_cxx_compiler_flag(-stdlib=libc++) + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel" OR + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "IntelLLVM") + add_cxx_compiler_flag(-nostdinc++) + message(WARNING "libc++ header path must be manually specified using CMAKE_CXX_FLAGS") + # Adding -nodefaultlibs directly to CMAKE__LINKER_FLAGS will break + # configuration checks such as 'find_package(Threads)' + list(APPEND BENCHMARK_CXX_LINKER_FLAGS -nodefaultlibs) + # -lc++ cannot be added directly to CMAKE__LINKER_FLAGS because + # linker flags appear before all linker inputs and -lc++ must appear after. 
+ list(APPEND BENCHMARK_CXX_LIBRARIES c++) + else() + message(FATAL_ERROR "-DBENCHMARK_USE_LIBCXX:BOOL=ON is not supported for compiler") + endif() +endif(BENCHMARK_USE_LIBCXX) + +set(EXTRA_CXX_FLAGS "") +if (WIN32 AND "${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + # Clang on Windows fails to compile the regex feature check under C++11 + set(EXTRA_CXX_FLAGS "-DCMAKE_CXX_STANDARD=14") +endif() + +# C++ feature checks +# Determine the correct regular expression engine to use +cxx_feature_check(STD_REGEX ${EXTRA_CXX_FLAGS}) +cxx_feature_check(GNU_POSIX_REGEX ${EXTRA_CXX_FLAGS}) +cxx_feature_check(POSIX_REGEX ${EXTRA_CXX_FLAGS}) +if(NOT HAVE_STD_REGEX AND NOT HAVE_GNU_POSIX_REGEX AND NOT HAVE_POSIX_REGEX) + message(FATAL_ERROR "Failed to determine the source files for the regular expression backend") +endif() +if (NOT BENCHMARK_ENABLE_EXCEPTIONS AND HAVE_STD_REGEX + AND NOT HAVE_GNU_POSIX_REGEX AND NOT HAVE_POSIX_REGEX) + message(WARNING "Using std::regex with exceptions disabled is not fully supported") +endif() + +cxx_feature_check(STEADY_CLOCK) +# Ensure we have pthreads +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) +cxx_feature_check(PTHREAD_AFFINITY) + +if (BENCHMARK_ENABLE_LIBPFM) + find_package(PFM REQUIRED) +endif() + +# Set up directories +include_directories(${PROJECT_SOURCE_DIR}/include) + +# Build the targets +add_subdirectory(src) + +if (BENCHMARK_ENABLE_TESTING) + enable_testing() + if (BENCHMARK_ENABLE_GTEST_TESTS AND + NOT (TARGET gtest AND TARGET gtest_main AND + TARGET gmock AND TARGET gmock_main)) + if (BENCHMARK_USE_BUNDLED_GTEST) + include(GoogleTest) + else() + find_package(GTest CONFIG REQUIRED) + add_library(gtest ALIAS GTest::gtest) + add_library(gtest_main ALIAS GTest::gtest_main) + add_library(gmock ALIAS GTest::gmock) + add_library(gmock_main ALIAS GTest::gmock_main) + endif() + endif() + add_subdirectory(test) +endif() diff --git a/third_party/benchmark/CONTRIBUTING.md b/third_party/benchmark/CONTRIBUTING.md new file 
mode 100644 index 0000000..43de4c9 --- /dev/null +++ b/third_party/benchmark/CONTRIBUTING.md @@ -0,0 +1,58 @@ +# How to contribute # + +We'd love to accept your patches and contributions to this project. There are +a just a few small guidelines you need to follow. + + +## Contributor License Agreement ## + +Contributions to any Google project must be accompanied by a Contributor +License Agreement. This is not a copyright **assignment**, it simply gives +Google permission to use and redistribute your contributions as part of the +project. + + * If you are an individual writing original source code and you're sure you + own the intellectual property, then you'll need to sign an [individual + CLA][]. + + * If you work for a company that wants to allow you to contribute your work, + then you'll need to sign a [corporate CLA][]. + +You generally only need to submit a CLA once, so if you've already submitted +one (even if it was for a different project), you probably don't need to do it +again. + +[individual CLA]: https://developers.google.com/open-source/cla/individual +[corporate CLA]: https://developers.google.com/open-source/cla/corporate + +Once your CLA is submitted (or if you already submitted one for +another Google project), make a commit adding yourself to the +[AUTHORS][] and [CONTRIBUTORS][] files. This commit can be part +of your first [pull request][]. + +[AUTHORS]: AUTHORS +[CONTRIBUTORS]: CONTRIBUTORS + + +## Submitting a patch ## + + 1. It's generally best to start by opening a new issue describing the bug or + feature you're intending to fix. Even if you think it's relatively minor, + it's helpful to know what people are working on. Mention in the initial + issue that you are planning to work on that bug or feature so that it can + be assigned to you. + + 1. Follow the normal process of [forking][] the project, and setup a new + branch to work in. 
It's important that each group of changes be done in + separate branches in order to ensure that a pull request only includes the + commits related to that bug or feature. + + 1. Do your best to have [well-formed commit messages][] for each change. + This provides consistency throughout the project, and ensures that commit + messages are able to be formatted properly by various git tools. + + 1. Finally, push the commits to your fork and submit a [pull request][]. + +[forking]: https://help.github.com/articles/fork-a-repo +[well-formed commit messages]: http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html +[pull request]: https://help.github.com/articles/creating-a-pull-request diff --git a/third_party/benchmark/CONTRIBUTORS b/third_party/benchmark/CONTRIBUTORS new file mode 100644 index 0000000..54aba7b --- /dev/null +++ b/third_party/benchmark/CONTRIBUTORS @@ -0,0 +1,98 @@ +# People who have agreed to one of the CLAs and can contribute patches. +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS, because Google holds the copyright. +# +# Names should be added to this file only after verifying that +# the individual or the individual's organization has agreed to +# the appropriate Contributor License Agreement, found here: +# +# https://developers.google.com/open-source/cla/individual +# https://developers.google.com/open-source/cla/corporate +# +# The agreement for individuals can be filled out on the web. +# +# When adding J Random Contributor's name to this file, +# either J's name or J's organization's name should be +# added to the AUTHORS file, depending on whether the +# individual or corporate CLA was used. +# +# Names should be added to this file as: +# Name +# +# Please keep the list sorted. 
+ +Abhina Sreeskantharajan +Albert Pretorius +Alex Steele +Andriy Berestovskyy +Arne Beer +Bátor Tallér +Billy Robert O'Neal III +Cezary Skrzyński +Chris Kennelly +Christian Wassermann +Christopher Seymour +Colin Braley +Cyrille Faucheux +Daniel Harvey +David Coeurjolly +Deniz Evrenci +Dominic Hamon +Dominik Czarnota +Dominik Korman +Donald Aingworth +Doug Evans +Eric Backus +Eric Fiselier +Eugene Zhuk +Evgeny Safronov +Fabien Pichot +Fanbo Meng +Federico Ficarelli +Felix Homann +Geoffrey Martin-Noble +Gergely Meszaros +Gergő Szitár +Hannes Hauswedell +Henrique Bucher +Ismael Jimenez Martinez +Iakov Sergeev +Jern-Kuan Leong +JianXiong Zhou +Joao Paulo Magalhaes +John Millikin +Jordan Williams +Jussi Knuuttila +Kaito Udagawa +Kai Wolf +Kishan Kumar +Lei Xu +Marcel Jacobse +Matt Clarkson +Maxim Vafin +Mike Apodaca +Min-Yih Hsu +Nick Hutchinson +Norman Heino +Oleksandr Sochka +Ori Livneh +Pascal Leroy +Paul Redmond +Pierre Phaneuf +Radoslav Yovchev +Raghu Raja +Rainer Orth +Raul Marin +Ray Glover +Robert Guo +Roman Lebedev +Sayan Bhattacharjee +Shuo Chen +Steven Wan +Tobias Schmidt +Tobias Ulvgård +Tom Madams +Yixuan Qiu +Yusuke Suzuki +Zbigniew Skowron diff --git a/third_party/benchmark/LICENSE b/third_party/benchmark/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/third_party/benchmark/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/benchmark/MODULE.bazel b/third_party/benchmark/MODULE.bazel new file mode 100644 index 0000000..62870f7 --- /dev/null +++ b/third_party/benchmark/MODULE.bazel @@ -0,0 +1,42 @@ +module( + name = "google_benchmark", + version = "1.9.1", +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.10") +bazel_dep(name = "rules_foreign_cc", version = "0.10.1") +bazel_dep(name = "rules_cc", version = "0.0.9") + +bazel_dep(name = "rules_python", version = "0.37.0", dev_dependency = True) +bazel_dep(name = "googletest", version = "1.14.0", dev_dependency = True, repo_name = "com_google_googletest") + +bazel_dep(name = "libpfm", version = "4.11.0") + +# Register a toolchain for Python 3.9 to be able to build numpy. Python +# versions >=3.10 are problematic. +# A second reason for this is to be able to build Python hermetically instead +# of relying on the changing default version from rules_python. 
+ +python = use_extension("@rules_python//python/extensions:python.bzl", "python", dev_dependency = True) +python.toolchain(python_version = "3.8") +python.toolchain(python_version = "3.9") +python.toolchain(python_version = "3.10") +python.toolchain(python_version = "3.11") +python.toolchain( + is_default = True, + python_version = "3.12", +) +python.toolchain(python_version = "3.13") + +pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip", dev_dependency = True) +pip.parse( + hub_name = "tools_pip_deps", + python_version = "3.9", + requirements_lock = "//tools:requirements.txt", +) +use_repo(pip, "tools_pip_deps") + +# -- bazel_dep definitions -- # + +bazel_dep(name = "nanobind_bazel", version = "2.2.0", dev_dependency = True) diff --git a/third_party/benchmark/README.md b/third_party/benchmark/README.md new file mode 100644 index 0000000..8e5428f --- /dev/null +++ b/third_party/benchmark/README.md @@ -0,0 +1,221 @@ +# Benchmark + +[![build-and-test](https://github.com/google/benchmark/workflows/build-and-test/badge.svg)](https://github.com/google/benchmark/actions?query=workflow%3Abuild-and-test) +[![bazel](https://github.com/google/benchmark/actions/workflows/bazel.yml/badge.svg)](https://github.com/google/benchmark/actions/workflows/bazel.yml) +[![pylint](https://github.com/google/benchmark/workflows/pylint/badge.svg)](https://github.com/google/benchmark/actions?query=workflow%3Apylint) +[![test-bindings](https://github.com/google/benchmark/workflows/test-bindings/badge.svg)](https://github.com/google/benchmark/actions?query=workflow%3Atest-bindings) +[![Coverage Status](https://coveralls.io/repos/google/benchmark/badge.svg)](https://coveralls.io/r/google/benchmark) + +[![Discord](https://discordapp.com/api/guilds/1125694995928719494/widget.png?style=shield)](https://discord.gg/cz7UX7wKC2) + +A library to benchmark code snippets, similar to unit tests. 
Example: + +```c++ +#include + +static void BM_SomeFunction(benchmark::State& state) { + // Perform setup here + for (auto _ : state) { + // This code gets timed + SomeFunction(); + } +} +// Register the function as a benchmark +BENCHMARK(BM_SomeFunction); +// Run the benchmark +BENCHMARK_MAIN(); +``` + +## Getting Started + +To get started, see [Requirements](#requirements) and +[Installation](#installation). See [Usage](#usage) for a full example and the +[User Guide](docs/user_guide.md) for a more comprehensive feature overview. + +It may also help to read the [Google Test documentation](https://github.com/google/googletest/blob/main/docs/primer.md) +as some of the structural aspects of the APIs are similar. + +## Resources + +[Discussion group](https://groups.google.com/d/forum/benchmark-discuss) + +IRC channels: +* [libera](https://libera.chat) #benchmark + +[Additional Tooling Documentation](docs/tools.md) + +[Assembly Testing Documentation](docs/AssemblyTests.md) + +[Building and installing Python bindings](docs/python_bindings.md) + +## Requirements + +The library can be used with C++03. However, it requires C++14 to build, +including compiler and standard library support. + +_See [dependencies.md](docs/dependencies.md) for more details regarding supported +compilers and standards._ + +If you have need for a particular compiler to be supported, patches are very welcome. + +See [Platform-Specific Build Instructions](docs/platform_specific_build_instructions.md). + +## Installation + +This describes the installation process using cmake. As pre-requisites, you'll +need git and cmake installed. + +_See [dependencies.md](docs/dependencies.md) for more details regarding supported +versions of build tools._ + +```bash +# Check out the library. +$ git clone https://github.com/google/benchmark.git +# Go to the library root directory +$ cd benchmark +# Make a build directory to place the build output. 
+$ cmake -E make_directory "build" +# Generate build system files with cmake, and download any dependencies. +$ cmake -E chdir "build" cmake -DBENCHMARK_DOWNLOAD_DEPENDENCIES=on -DCMAKE_BUILD_TYPE=Release ../ +# or, starting with CMake 3.13, use a simpler form: +# cmake -DCMAKE_BUILD_TYPE=Release -S . -B "build" +# Build the library. +$ cmake --build "build" --config Release +``` +This builds the `benchmark` and `benchmark_main` libraries and tests. +On a unix system, the build directory should now look something like this: + +``` +/benchmark + /build + /src + /libbenchmark.a + /libbenchmark_main.a + /test + ... +``` + +Next, you can run the tests to check the build. + +```bash +$ cmake -E chdir "build" ctest --build-config Release +``` + +If you want to install the library globally, also run: + +``` +sudo cmake --build "build" --config Release --target install +``` + +Note that Google Benchmark requires Google Test to build and run the tests. This +dependency can be provided two ways: + +* Checkout the Google Test sources into `benchmark/googletest`. +* Otherwise, if `-DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON` is specified during + configuration as above, the library will automatically download and build + any required dependencies. + +If you do not wish to build and run the tests, add `-DBENCHMARK_ENABLE_GTEST_TESTS=OFF` +to `CMAKE_ARGS`. + +### Debug vs Release + +By default, benchmark builds as a debug library. You will see a warning in the +output when this is the case. To build it as a release library instead, add +`-DCMAKE_BUILD_TYPE=Release` when generating the build system files, as shown +above. The use of `--config Release` in build commands is needed to properly +support multi-configuration tools (like Visual Studio for example) and can be +skipped for other build systems (like Makefile). + +To enable link-time optimisation, also add `-DBENCHMARK_ENABLE_LTO=true` when +generating the build system files. 
+ +If you are using gcc, you might need to set `GCC_AR` and `GCC_RANLIB` cmake +cache variables, if autodetection fails. + +If you are using clang, you may need to set `LLVMAR_EXECUTABLE`, +`LLVMNM_EXECUTABLE` and `LLVMRANLIB_EXECUTABLE` cmake cache variables. + +To enable sanitizer checks (eg., `asan` and `tsan`), add: +``` + -DCMAKE_C_FLAGS="-g -O2 -fno-omit-frame-pointer -fsanitize=address -fsanitize=thread -fno-sanitize-recover=all" + -DCMAKE_CXX_FLAGS="-g -O2 -fno-omit-frame-pointer -fsanitize=address -fsanitize=thread -fno-sanitize-recover=all " +``` + +### Stable and Experimental Library Versions + +The main branch contains the latest stable version of the benchmarking library; +the API of which can be considered largely stable, with source breaking changes +being made only upon the release of a new major version. + +Newer, experimental, features are implemented and tested on the +[`v2` branch](https://github.com/google/benchmark/tree/v2). Users who wish +to use, test, and provide feedback on the new features are encouraged to try +this branch. However, this branch provides no stability guarantees and reserves +the right to change and break the API at any time. + +## Usage + +### Basic usage + +Define a function that executes the code to measure, register it as a benchmark +function using the `BENCHMARK` macro, and ensure an appropriate `main` function +is available: + +```c++ +#include + +static void BM_StringCreation(benchmark::State& state) { + for (auto _ : state) + std::string empty_string; +} +// Register the function as a benchmark +BENCHMARK(BM_StringCreation); + +// Define another benchmark +static void BM_StringCopy(benchmark::State& state) { + std::string x = "hello"; + for (auto _ : state) + std::string copy(x); +} +BENCHMARK(BM_StringCopy); + +BENCHMARK_MAIN(); +``` + +To run the benchmark, compile and link against the `benchmark` library +(libbenchmark.a/.so). 
If you followed the build steps above, this library will +be under the build directory you created. + +```bash +# Example on linux after running the build steps above. Assumes the +# `benchmark` and `build` directories are under the current directory. +$ g++ mybenchmark.cc -std=c++11 -isystem benchmark/include \ + -Lbenchmark/build/src -lbenchmark -lpthread -o mybenchmark +``` + +Alternatively, link against the `benchmark_main` library and remove +`BENCHMARK_MAIN();` above to get the same behavior. + +The compiled executable will run all benchmarks by default. Pass the `--help` +flag for option information or see the [User Guide](docs/user_guide.md). + +### Usage with CMake + +If using CMake, it is recommended to link against the project-provided +`benchmark::benchmark` and `benchmark::benchmark_main` targets using +`target_link_libraries`. +It is possible to use ```find_package``` to import an installed version of the +library. +```cmake +find_package(benchmark REQUIRED) +``` +Alternatively, ```add_subdirectory``` will incorporate the library directly in +to one's CMake project. +```cmake +add_subdirectory(benchmark) +``` +Either way, link to the library as follows. 
+```cmake +target_link_libraries(MyTarget benchmark::benchmark) +``` diff --git a/third_party/benchmark/WORKSPACE b/third_party/benchmark/WORKSPACE new file mode 100644 index 0000000..5032024 --- /dev/null +++ b/third_party/benchmark/WORKSPACE @@ -0,0 +1,24 @@ +workspace(name = "com_github_google_benchmark") + +load("//:bazel/benchmark_deps.bzl", "benchmark_deps") + +benchmark_deps() + +load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies") + +rules_foreign_cc_dependencies() + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() + +load("@rules_python//python:pip.bzl", "pip_parse") + +pip_parse( + name = "tools_pip_deps", + requirements_lock = "//tools:requirements.txt", +) + +load("@tools_pip_deps//:requirements.bzl", "install_deps") + +install_deps() diff --git a/third_party/benchmark/WORKSPACE.bzlmod b/third_party/benchmark/WORKSPACE.bzlmod new file mode 100644 index 0000000..9526376 --- /dev/null +++ b/third_party/benchmark/WORKSPACE.bzlmod @@ -0,0 +1,2 @@ +# This file marks the root of the Bazel workspace. +# See MODULE.bazel for dependencies and setup. 
diff --git a/third_party/benchmark/_config.yml b/third_party/benchmark/_config.yml new file mode 100644 index 0000000..1fa5ff8 --- /dev/null +++ b/third_party/benchmark/_config.yml @@ -0,0 +1,2 @@ +theme: jekyll-theme-midnight +markdown: GFM diff --git a/third_party/benchmark/appveyor.yml b/third_party/benchmark/appveyor.yml new file mode 100644 index 0000000..81da955 --- /dev/null +++ b/third_party/benchmark/appveyor.yml @@ -0,0 +1,50 @@ +version: '{build}' + +image: Visual Studio 2017 + +configuration: + - Debug + - Release + +environment: + matrix: + - compiler: msvc-15-seh + generator: "Visual Studio 15 2017" + + - compiler: msvc-15-seh + generator: "Visual Studio 15 2017 Win64" + + - compiler: msvc-14-seh + generator: "Visual Studio 14 2015" + + - compiler: msvc-14-seh + generator: "Visual Studio 14 2015 Win64" + + - compiler: gcc-5.3.0-posix + generator: "MinGW Makefiles" + cxx_path: 'C:\mingw-w64\i686-5.3.0-posix-dwarf-rt_v4-rev0\mingw32\bin' + APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2015 + +matrix: + fast_finish: true + +install: + # git bash conflicts with MinGW makefiles + - if "%generator%"=="MinGW Makefiles" (set "PATH=%PATH:C:\Program Files\Git\usr\bin;=%") + - if not "%cxx_path%"=="" (set "PATH=%PATH%;%cxx_path%") + +build_script: + - md _build -Force + - cd _build + - echo %configuration% + - cmake -G "%generator%" "-DCMAKE_BUILD_TYPE=%configuration%" -DBENCHMARK_DOWNLOAD_DEPENDENCIES=ON .. + - cmake --build . 
--config %configuration% + +test_script: + - ctest --build-config %configuration% --timeout 300 --output-on-failure + +artifacts: + - path: '_build/CMakeFiles/*.log' + name: logs + - path: '_build/Testing/**/*.xml' + name: test_results diff --git a/third_party/benchmark/bazel/benchmark_deps.bzl b/third_party/benchmark/bazel/benchmark_deps.bzl new file mode 100644 index 0000000..cb908cd --- /dev/null +++ b/third_party/benchmark/bazel/benchmark_deps.bzl @@ -0,0 +1,62 @@ +""" +This file contains the Bazel build dependencies for Google Benchmark (both C++ source and Python bindings). +""" + +load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def benchmark_deps(): + """Loads dependencies required to build Google Benchmark.""" + + if "bazel_skylib" not in native.existing_rules(): + http_archive( + name = "bazel_skylib", + sha256 = "cd55a062e763b9349921f0f5db8c3933288dc8ba4f76dd9416aac68acee3cb94", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz", + ], + ) + + if "rules_foreign_cc" not in native.existing_rules(): + http_archive( + name = "rules_foreign_cc", + sha256 = "476303bd0f1b04cc311fc258f1708a5f6ef82d3091e53fd1977fa20383425a6a", + strip_prefix = "rules_foreign_cc-0.10.1", + url = "https://github.com/bazelbuild/rules_foreign_cc/releases/download/0.10.1/rules_foreign_cc-0.10.1.tar.gz", + ) + + if "rules_python" not in native.existing_rules(): + http_archive( + name = "rules_python", + sha256 = "e85ae30de33625a63eca7fc40a94fea845e641888e52f32b6beea91e8b1b2793", + strip_prefix = "rules_python-0.27.1", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.27.1/rules_python-0.27.1.tar.gz", + ) + + if "com_google_googletest" not in native.existing_rules(): + new_git_repository( + name = 
"com_google_googletest", + remote = "https://github.com/google/googletest.git", + tag = "release-1.12.1", + ) + + if "nanobind" not in native.existing_rules(): + new_git_repository( + name = "nanobind", + remote = "https://github.com/wjakob/nanobind.git", + tag = "v1.9.2", + build_file = "@//bindings/python:nanobind.BUILD", + recursive_init_submodules = True, + ) + + if "libpfm" not in native.existing_rules(): + # Downloaded from v4.9.0 tag at https://sourceforge.net/p/perfmon2/libpfm4/ref/master/tags/ + http_archive( + name = "libpfm", + build_file = str(Label("//tools:libpfm.BUILD.bazel")), + sha256 = "5da5f8872bde14b3634c9688d980f68bda28b510268723cc12973eedbab9fecc", + type = "tar.gz", + strip_prefix = "libpfm-4.11.0", + urls = ["https://sourceforge.net/projects/perfmon2/files/libpfm4/libpfm-4.11.0.tar.gz/download"], + ) diff --git a/third_party/benchmark/bindings/python/google_benchmark/BUILD b/third_party/benchmark/bindings/python/google_benchmark/BUILD new file mode 100644 index 0000000..30e3893 --- /dev/null +++ b/third_party/benchmark/bindings/python/google_benchmark/BUILD @@ -0,0 +1,33 @@ +load("@nanobind_bazel//:build_defs.bzl", "nanobind_extension", "nanobind_stubgen") + +py_library( + name = "google_benchmark", + srcs = ["__init__.py"], + visibility = ["//visibility:public"], + deps = [ + ":_benchmark", + ], +) + +nanobind_extension( + name = "_benchmark", + srcs = ["benchmark.cc"], + deps = ["//:benchmark"], +) + +nanobind_stubgen( + name = "benchmark_stubgen", + marker_file = "bindings/python/google_benchmark/py.typed", + module = ":_benchmark", +) + +py_test( + name = "example", + srcs = ["example.py"], + python_version = "PY3", + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + ":google_benchmark", + ], +) diff --git a/third_party/benchmark/bindings/python/google_benchmark/__init__.py b/third_party/benchmark/bindings/python/google_benchmark/__init__.py new file mode 100644 index 0000000..7006352 --- /dev/null +++ 
b/third_party/benchmark/bindings/python/google_benchmark/__init__.py @@ -0,0 +1,142 @@ +# Copyright 2020 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python benchmarking utilities. + +Example usage: + import google_benchmark as benchmark + + @benchmark.register + def my_benchmark(state): + ... # Code executed outside `while` loop is not timed. + + while state: + ... # Code executed within `while` loop is timed. + + if __name__ == '__main__': + benchmark.main() +""" + +import atexit + +from absl import app + +from google_benchmark import _benchmark +from google_benchmark._benchmark import ( + Counter as Counter, + State as State, + kMicrosecond as kMicrosecond, + kMillisecond as kMillisecond, + kNanosecond as kNanosecond, + kSecond as kSecond, + o1 as o1, + oAuto as oAuto, + oLambda as oLambda, + oLogN as oLogN, + oN as oN, + oNCubed as oNCubed, + oNLogN as oNLogN, + oNone as oNone, + oNSquared as oNSquared, +) + +__version__ = "1.9.1" + + +class __OptionMaker: + """A stateless class to collect benchmark options. + + Collect all decorator calls like @option.range(start=0, limit=1<<5). 
+ """ + + class Options: + """Pure data class to store options calls, along with the benchmarked function.""" + + def __init__(self, func): + self.func = func + self.builder_calls = [] + + @classmethod + def make(cls, func_or_options): + """Make Options from Options or the benchmarked function.""" + if isinstance(func_or_options, cls.Options): + return func_or_options + return cls.Options(func_or_options) + + def __getattr__(self, builder_name): + """Append option call in the Options.""" + + # The function that get returned on @option.range(start=0, limit=1<<5). + def __builder_method(*args, **kwargs): + # The decorator that get called, either with the benchmared function + # or the previous Options + def __decorator(func_or_options): + options = self.make(func_or_options) + options.builder_calls.append((builder_name, args, kwargs)) + # The decorator returns Options so it is not technically a decorator + # and needs a final call to @register + return options + + return __decorator + + return __builder_method + + +# Alias for nicer API. +# We have to instantiate an object, even if stateless, to be able to use __getattr__ +# on option.range +option = __OptionMaker() + + +def register(undefined=None, *, name=None): + """Register function for benchmarking.""" + if undefined is None: + # Decorator is called without parenthesis so we return a decorator + return lambda f: register(f, name=name) + + # We have either the function to benchmark (simple case) or an instance of Options + # (@option._ case). 
+    options = __OptionMaker.make(undefined)
+
+    if name is None:
+        name = options.func.__name__
+
+    # We register the benchmark and reproduce all the @option._ calls onto the
+    # benchmark builder pattern
+    benchmark = _benchmark.RegisterBenchmark(name, options.func)
+    for name, args, kwargs in options.builder_calls[::-1]:
+        getattr(benchmark, name)(*args, **kwargs)
+
+    # return the benchmarked function because the decorator does not modify it
+    return options.func
+
+
+def _flags_parser(argv):
+    argv = _benchmark.Initialize(argv)
+    return app.parse_flags_with_usage(argv)
+
+
+def _run_benchmarks(argv):
+    if len(argv) > 1:
+        raise app.UsageError("Too many command-line arguments.")
+    return _benchmark.RunSpecifiedBenchmarks()
+
+
+def main(argv=None):
+    return app.run(_run_benchmarks, argv=argv, flags_parser=_flags_parser)
+
+
+# Methods for use with custom main function.
+initialize = _benchmark.Initialize
+run_benchmarks = _benchmark.RunSpecifiedBenchmarks
+atexit.register(_benchmark.ClearRegisteredBenchmarks)
diff --git a/third_party/benchmark/bindings/python/google_benchmark/benchmark.cc b/third_party/benchmark/bindings/python/google_benchmark/benchmark.cc
new file mode 100644
index 0000000..a935822
--- /dev/null
+++ b/third_party/benchmark/bindings/python/google_benchmark/benchmark.cc
@@ -0,0 +1,184 @@
+// Benchmark for Python.
+
+#include "benchmark/benchmark.h"
+
+#include "nanobind/nanobind.h"
+#include "nanobind/operators.h"
+#include "nanobind/stl/bind_map.h"
+#include "nanobind/stl/string.h"
+#include "nanobind/stl/vector.h"
+
+NB_MAKE_OPAQUE(benchmark::UserCounters);
+
+namespace {
+namespace nb = nanobind;
+
+std::vector<std::string> Initialize(const std::vector<std::string>& argv) {
+  // The `argv` pointers here become invalid when this function returns, but
+  // benchmark holds the pointer to `argv[0]`. We create a static copy of it
+  // so it persists, and replace the pointer below.
+  static std::string executable_name(argv[0]);
+  std::vector<char*> ptrs;
+  ptrs.reserve(argv.size());
+  for (auto& arg : argv) {
+    ptrs.push_back(const_cast<char*>(arg.c_str()));
+  }
+  ptrs[0] = const_cast<char*>(executable_name.c_str());
+  int argc = static_cast<int>(argv.size());
+  benchmark::Initialize(&argc, ptrs.data());
+  std::vector<std::string> remaining_argv;
+  remaining_argv.reserve(argc);
+  for (int i = 0; i < argc; ++i) {
+    remaining_argv.emplace_back(ptrs[i]);
+  }
+  return remaining_argv;
+}
+
+benchmark::internal::Benchmark* RegisterBenchmark(const std::string& name,
+                                                  nb::callable f) {
+  return benchmark::RegisterBenchmark(
+      name, [f](benchmark::State& state) { f(&state); });
+}
+
+NB_MODULE(_benchmark, m) {
+
+  using benchmark::TimeUnit;
+  nb::enum_<TimeUnit>(m, "TimeUnit")
+      .value("kNanosecond", TimeUnit::kNanosecond)
+      .value("kMicrosecond", TimeUnit::kMicrosecond)
+      .value("kMillisecond", TimeUnit::kMillisecond)
+      .value("kSecond", TimeUnit::kSecond)
+      .export_values();
+
+  using benchmark::BigO;
+  nb::enum_<BigO>(m, "BigO")
+      .value("oNone", BigO::oNone)
+      .value("o1", BigO::o1)
+      .value("oN", BigO::oN)
+      .value("oNSquared", BigO::oNSquared)
+      .value("oNCubed", BigO::oNCubed)
+      .value("oLogN", BigO::oLogN)
+      .value("oNLogN", BigO::oNLogN)
+      .value("oAuto", BigO::oAuto)
+      .value("oLambda", BigO::oLambda)
+      .export_values();
+
+  using benchmark::internal::Benchmark;
+  nb::class_<Benchmark>(m, "Benchmark")
+      // For methods returning a pointer to the current object, reference
+      // return policy is used to ask nanobind not to take ownership of the
+      // returned object and avoid calling delete on it.
+      // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
+      //
+      // For methods taking a const std::vector<...>&, a copy is created
+      // because it is bound to a Python list.
+      // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
+      .def("unit", &Benchmark::Unit, nb::rv_policy::reference)
+      .def("arg", &Benchmark::Arg, nb::rv_policy::reference)
+      .def("args", &Benchmark::Args, nb::rv_policy::reference)
+      .def("range", &Benchmark::Range, nb::rv_policy::reference,
+           nb::arg("start"), nb::arg("limit"))
+      .def("dense_range", &Benchmark::DenseRange,
+           nb::rv_policy::reference, nb::arg("start"),
+           nb::arg("limit"), nb::arg("step") = 1)
+      .def("ranges", &Benchmark::Ranges, nb::rv_policy::reference)
+      .def("args_product", &Benchmark::ArgsProduct,
+           nb::rv_policy::reference)
+      .def("arg_name", &Benchmark::ArgName, nb::rv_policy::reference)
+      .def("arg_names", &Benchmark::ArgNames,
+           nb::rv_policy::reference)
+      .def("range_pair", &Benchmark::RangePair,
+           nb::rv_policy::reference, nb::arg("lo1"), nb::arg("hi1"),
+           nb::arg("lo2"), nb::arg("hi2"))
+      .def("range_multiplier", &Benchmark::RangeMultiplier,
+           nb::rv_policy::reference)
+      .def("min_time", &Benchmark::MinTime, nb::rv_policy::reference)
+      .def("min_warmup_time", &Benchmark::MinWarmUpTime,
+           nb::rv_policy::reference)
+      .def("iterations", &Benchmark::Iterations,
+           nb::rv_policy::reference)
+      .def("repetitions", &Benchmark::Repetitions,
+           nb::rv_policy::reference)
+      .def("report_aggregates_only", &Benchmark::ReportAggregatesOnly,
+           nb::rv_policy::reference, nb::arg("value") = true)
+      .def("display_aggregates_only", &Benchmark::DisplayAggregatesOnly,
+           nb::rv_policy::reference, nb::arg("value") = true)
+      .def("measure_process_cpu_time", &Benchmark::MeasureProcessCPUTime,
+           nb::rv_policy::reference)
+      .def("use_real_time", &Benchmark::UseRealTime,
+           nb::rv_policy::reference)
+      .def("use_manual_time", &Benchmark::UseManualTime,
+           nb::rv_policy::reference)
+      .def(
+          "complexity",
+          (Benchmark * (Benchmark::*)(benchmark::BigO)) & Benchmark::Complexity,
+          nb::rv_policy::reference,
+          nb::arg("complexity") = benchmark::oAuto);
+
+  using benchmark::Counter;
+  nb::class_<Counter> py_counter(m, "Counter");
+
+  nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic(), nb::is_flag())
+      .value("kDefaults", Counter::Flags::kDefaults)
+      .value("kIsRate", Counter::Flags::kIsRate)
+      .value("kAvgThreads", Counter::Flags::kAvgThreads)
+      .value("kAvgThreadsRate", Counter::Flags::kAvgThreadsRate)
+      .value("kIsIterationInvariant", Counter::Flags::kIsIterationInvariant)
+      .value("kIsIterationInvariantRate",
+             Counter::Flags::kIsIterationInvariantRate)
+      .value("kAvgIterations", Counter::Flags::kAvgIterations)
+      .value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate)
+      .value("kInvert", Counter::Flags::kInvert)
+      .export_values();
+
+  nb::enum_<Counter::OneK>(py_counter, "OneK")
+      .value("kIs1000", Counter::OneK::kIs1000)
+      .value("kIs1024", Counter::OneK::kIs1024)
+      .export_values();
+
+  py_counter
+      .def(nb::init<double, Counter::Flags, Counter::OneK>(),
+           nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
+           nb::arg("k") = Counter::kIs1000)
+      .def("__init__",
+           ([](Counter* c, double value) { new (c) Counter(value); }))
+      .def_rw("value", &Counter::value)
+      .def_rw("flags", &Counter::flags)
+      .def_rw("oneK", &Counter::oneK)
+      .def(nb::init_implicit<double>());
+
+  nb::implicitly_convertible<int, Counter>();
+
+  nb::bind_map<benchmark::UserCounters>(m, "UserCounters");
+
+  using benchmark::State;
+  nb::class_<State>(m, "State")
+      .def("__bool__", &State::KeepRunning)
+      .def_prop_ro("keep_running", &State::KeepRunning)
+      .def("pause_timing", &State::PauseTiming)
+      .def("resume_timing", &State::ResumeTiming)
+      .def("skip_with_error", &State::SkipWithError)
+      .def_prop_ro("error_occurred", &State::error_occurred)
+      .def("set_iteration_time", &State::SetIterationTime)
+      .def_prop_rw("bytes_processed", &State::bytes_processed,
+                   &State::SetBytesProcessed)
+      .def_prop_rw("complexity_n", &State::complexity_length_n,
+                   &State::SetComplexityN)
+      .def_prop_rw("items_processed", &State::items_processed,
+                   &State::SetItemsProcessed)
+      .def("set_label", &State::SetLabel)
+      .def("range", &State::range, nb::arg("pos") = 0)
+      .def_prop_ro("iterations", &State::iterations)
+      .def_prop_ro("name", &State::name)
+      .def_rw("counters", &State::counters)
+      .def_prop_ro("thread_index", &State::thread_index)
+      .def_prop_ro("threads", &State::threads);
+
+  m.def("Initialize", Initialize);
+  m.def("RegisterBenchmark", RegisterBenchmark,
+        nb::rv_policy::reference);
+  m.def("RunSpecifiedBenchmarks",
+        []() { benchmark::RunSpecifiedBenchmarks(); });
+  m.def("ClearRegisteredBenchmarks", benchmark::ClearRegisteredBenchmarks);
+};
+}  // namespace
diff --git a/third_party/benchmark/bindings/python/google_benchmark/example.py b/third_party/benchmark/bindings/python/google_benchmark/example.py
new file mode 100644
index 0000000..b92245e
--- /dev/null
+++ b/third_party/benchmark/bindings/python/google_benchmark/example.py
@@ -0,0 +1,140 @@
+# Copyright 2020 Google Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Example of Python using C++ benchmark framework.
+
+To run this example, you must first install the `google_benchmark` Python package.
+
+To install using `setup.py`, download and extract the `google_benchmark` source.
+In the extracted directory, execute: + python setup.py install +""" + +import random +import time + +import google_benchmark as benchmark +from google_benchmark import Counter + + +@benchmark.register +def empty(state): + while state: + pass + + +@benchmark.register +def sum_million(state): + while state: + sum(range(1_000_000)) + + +@benchmark.register +def pause_timing(state): + """Pause timing every iteration.""" + while state: + # Construct a list of random ints every iteration without timing it + state.pause_timing() + random_list = [random.randint(0, 100) for _ in range(100)] + state.resume_timing() + # Time the in place sorting algorithm + random_list.sort() + + +@benchmark.register +def skipped(state): + if True: # Test some predicate here. + state.skip_with_error("some error") + return # NOTE: You must explicitly return, or benchmark will continue. + + ... # Benchmark code would be here. + + +@benchmark.register +@benchmark.option.use_manual_time() +def manual_timing(state): + while state: + # Manually count Python CPU time + start = time.perf_counter() # perf_counter_ns() in Python 3.7+ + # Something to benchmark + time.sleep(0.01) + end = time.perf_counter() + state.set_iteration_time(end - start) + + +@benchmark.register +def custom_counters(state): + """Collect custom metric using benchmark.Counter.""" + num_foo = 0.0 + while state: + # Benchmark some code here + pass + # Collect some custom metric named foo + num_foo += 0.13 + + # Automatic Counter from numbers. + state.counters["foo"] = num_foo + # Set a counter as a rate. + state.counters["foo_rate"] = Counter(num_foo, Counter.kIsRate) + # Set a counter as an inverse of rate. + state.counters["foo_inv_rate"] = Counter( + num_foo, Counter.kIsRate | Counter.kInvert + ) + # Set a counter as a thread-average quantity. 
+ state.counters["foo_avg"] = Counter(num_foo, Counter.kAvgThreads) + # There's also a combined flag: + state.counters["foo_avg_rate"] = Counter(num_foo, Counter.kAvgThreadsRate) + + +@benchmark.register +@benchmark.option.measure_process_cpu_time() +@benchmark.option.use_real_time() +def with_options(state): + while state: + sum(range(1_000_000)) + + +@benchmark.register(name="sum_million_microseconds") +@benchmark.option.unit(benchmark.kMicrosecond) +def with_options2(state): + while state: + sum(range(1_000_000)) + + +@benchmark.register +@benchmark.option.arg(100) +@benchmark.option.arg(1000) +def passing_argument(state): + while state: + sum(range(state.range(0))) + + +@benchmark.register +@benchmark.option.range(8, limit=8 << 10) +def using_range(state): + while state: + sum(range(state.range(0))) + + +@benchmark.register +@benchmark.option.range_multiplier(2) +@benchmark.option.range(1 << 10, 1 << 18) +@benchmark.option.complexity(benchmark.oN) +def computing_complexity(state): + while state: + sum(range(state.range(0))) + state.complexity_n = state.range(0) + + +if __name__ == "__main__": + benchmark.main() diff --git a/third_party/benchmark/cmake/AddCXXCompilerFlag.cmake b/third_party/benchmark/cmake/AddCXXCompilerFlag.cmake new file mode 100644 index 0000000..858589e --- /dev/null +++ b/third_party/benchmark/cmake/AddCXXCompilerFlag.cmake @@ -0,0 +1,78 @@ +# - Adds a compiler flag if it is supported by the compiler +# +# This function checks that the supplied compiler flag is supported and then +# adds it to the corresponding compiler flags +# +# add_cxx_compiler_flag( []) +# +# - Example +# +# include(AddCXXCompilerFlag) +# add_cxx_compiler_flag(-Wall) +# add_cxx_compiler_flag(-no-strict-aliasing RELEASE) +# Requires CMake 2.6+ + +if(__add_cxx_compiler_flag) + return() +endif() +set(__add_cxx_compiler_flag INCLUDED) + +include(CheckCXXCompilerFlag) + +function(mangle_compiler_flag FLAG OUTPUT) + string(TOUPPER "HAVE_CXX_FLAG_${FLAG}" SANITIZED_FLAG) + 
string(REPLACE "+" "X" SANITIZED_FLAG ${SANITIZED_FLAG}) + string(REGEX REPLACE "[^A-Za-z_0-9]" "_" SANITIZED_FLAG ${SANITIZED_FLAG}) + string(REGEX REPLACE "_+" "_" SANITIZED_FLAG ${SANITIZED_FLAG}) + set(${OUTPUT} "${SANITIZED_FLAG}" PARENT_SCOPE) +endfunction(mangle_compiler_flag) + +function(add_cxx_compiler_flag FLAG) + mangle_compiler_flag("${FLAG}" MANGLED_FLAG) + set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${FLAG}") + check_cxx_compiler_flag("${FLAG}" ${MANGLED_FLAG}) + set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS}") + if(${MANGLED_FLAG}) + if(ARGC GREATER 1) + set(VARIANT ${ARGV1}) + string(TOUPPER "_${VARIANT}" VARIANT) + else() + set(VARIANT "") + endif() + set(CMAKE_CXX_FLAGS${VARIANT} "${CMAKE_CXX_FLAGS${VARIANT}} ${BENCHMARK_CXX_FLAGS${VARIANT}} ${FLAG}" PARENT_SCOPE) + endif() +endfunction() + +function(add_required_cxx_compiler_flag FLAG) + mangle_compiler_flag("${FLAG}" MANGLED_FLAG) + set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${FLAG}") + check_cxx_compiler_flag("${FLAG}" ${MANGLED_FLAG}) + set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS}") + if(${MANGLED_FLAG}) + if(ARGC GREATER 1) + set(VARIANT ${ARGV1}) + string(TOUPPER "_${VARIANT}" VARIANT) + else() + set(VARIANT "") + endif() + set(CMAKE_CXX_FLAGS${VARIANT} "${CMAKE_CXX_FLAGS${VARIANT}} ${FLAG}" PARENT_SCOPE) + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${FLAG}" PARENT_SCOPE) + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${FLAG}" PARENT_SCOPE) + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} ${FLAG}" PARENT_SCOPE) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${FLAG}" PARENT_SCOPE) + else() + message(FATAL_ERROR "Required flag '${FLAG}' is not supported by the compiler") + endif() +endfunction() + +function(check_cxx_warning_flag FLAG) + mangle_compiler_flag("${FLAG}" MANGLED_FLAG) + 
set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}") + # Add -Werror to ensure the compiler generates an error if the warning flag + # doesn't exist. + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -Werror ${FLAG}") + check_cxx_compiler_flag("${FLAG}" ${MANGLED_FLAG}) + set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS}") +endfunction() diff --git a/third_party/benchmark/cmake/CXXFeatureCheck.cmake b/third_party/benchmark/cmake/CXXFeatureCheck.cmake new file mode 100644 index 0000000..e514826 --- /dev/null +++ b/third_party/benchmark/cmake/CXXFeatureCheck.cmake @@ -0,0 +1,82 @@ +# - Compile and run code to check for C++ features +# +# This functions compiles a source file under the `cmake` folder +# and adds the corresponding `HAVE_[FILENAME]` flag to the CMake +# environment +# +# cxx_feature_check( []) +# +# - Example +# +# include(CXXFeatureCheck) +# cxx_feature_check(STD_REGEX) +# Requires CMake 2.8.12+ + +if(__cxx_feature_check) + return() +endif() +set(__cxx_feature_check INCLUDED) + +option(CXXFEATURECHECK_DEBUG OFF) + +function(cxx_feature_check FILE) + string(TOLOWER ${FILE} FILE) + string(TOUPPER ${FILE} VAR) + string(TOUPPER "HAVE_${VAR}" FEATURE) + if (DEFINED HAVE_${VAR}) + set(HAVE_${VAR} 1 PARENT_SCOPE) + add_definitions(-DHAVE_${VAR}) + return() + endif() + + set(FEATURE_CHECK_CMAKE_FLAGS ${BENCHMARK_CXX_LINKER_FLAGS}) + if (ARGC GREATER 1) + message(STATUS "Enabling additional flags: ${ARGV1}") + list(APPEND FEATURE_CHECK_CMAKE_FLAGS ${ARGV1}) + endif() + + if (NOT DEFINED COMPILE_${FEATURE}) + if(CMAKE_CROSSCOMPILING) + message(STATUS "Cross-compiling to test ${FEATURE}") + try_compile(COMPILE_${FEATURE} + ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/cmake/${FILE}.cpp + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED ON + CMAKE_FLAGS ${FEATURE_CHECK_CMAKE_FLAGS} + LINK_LIBRARIES ${BENCHMARK_CXX_LIBRARIES} + OUTPUT_VARIABLE COMPILE_OUTPUT_VAR) + if(COMPILE_${FEATURE}) + message(WARNING + "If you see build failures due to cross compilation, try 
setting HAVE_${VAR} to 0") + set(RUN_${FEATURE} 0 CACHE INTERNAL "") + else() + set(RUN_${FEATURE} 1 CACHE INTERNAL "") + endif() + else() + message(STATUS "Compiling and running to test ${FEATURE}") + try_run(RUN_${FEATURE} COMPILE_${FEATURE} + ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/cmake/${FILE}.cpp + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED ON + CMAKE_FLAGS ${FEATURE_CHECK_CMAKE_FLAGS} + LINK_LIBRARIES ${BENCHMARK_CXX_LIBRARIES} + COMPILE_OUTPUT_VARIABLE COMPILE_OUTPUT_VAR) + endif() + endif() + + if(RUN_${FEATURE} EQUAL 0) + message(STATUS "Performing Test ${FEATURE} -- success") + set(HAVE_${VAR} 1 PARENT_SCOPE) + add_definitions(-DHAVE_${VAR}) + else() + if(NOT COMPILE_${FEATURE}) + if(CXXFEATURECHECK_DEBUG) + message(STATUS "Performing Test ${FEATURE} -- failed to compile: ${COMPILE_OUTPUT_VAR}") + else() + message(STATUS "Performing Test ${FEATURE} -- failed to compile") + endif() + else() + message(STATUS "Performing Test ${FEATURE} -- compiled but failed to run") + endif() + endif() +endfunction() diff --git a/third_party/benchmark/cmake/Config.cmake.in b/third_party/benchmark/cmake/Config.cmake.in new file mode 100644 index 0000000..3659cfa --- /dev/null +++ b/third_party/benchmark/cmake/Config.cmake.in @@ -0,0 +1,11 @@ +@PACKAGE_INIT@ + +include (CMakeFindDependencyMacro) + +find_dependency (Threads) + +if (@BENCHMARK_ENABLE_LIBPFM@) + find_dependency (PFM) +endif() + +include("${CMAKE_CURRENT_LIST_DIR}/@targets_export_name@.cmake") diff --git a/third_party/benchmark/cmake/GetGitVersion.cmake b/third_party/benchmark/cmake/GetGitVersion.cmake new file mode 100644 index 0000000..b021010 --- /dev/null +++ b/third_party/benchmark/cmake/GetGitVersion.cmake @@ -0,0 +1,36 @@ +# - Returns a version string from Git tags +# +# This function inspects the annotated git tags for the project and returns a string +# into a CMake variable +# +# get_git_version() +# +# - Example +# +# include(GetGitVersion) +# get_git_version(GIT_VERSION) +# +# Requires CMake 
2.8.11+ +find_package(Git) + +if(__get_git_version) + return() +endif() +set(__get_git_version INCLUDED) + +function(get_git_version var) + if(GIT_EXECUTABLE) + execute_process(COMMAND ${GIT_EXECUTABLE} describe --tags --match "v[0-9]*.[0-9]*.[0-9]*" --abbrev=8 --dirty + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + RESULT_VARIABLE status + OUTPUT_VARIABLE GIT_VERSION + ERROR_QUIET) + if(status) + set(GIT_VERSION "v0.0.0") + endif() + else() + set(GIT_VERSION "v0.0.0") + endif() + + set(${var} ${GIT_VERSION} PARENT_SCOPE) +endfunction() diff --git a/third_party/benchmark/cmake/GoogleTest.cmake b/third_party/benchmark/cmake/GoogleTest.cmake new file mode 100644 index 0000000..e66e9d1 --- /dev/null +++ b/third_party/benchmark/cmake/GoogleTest.cmake @@ -0,0 +1,58 @@ +# Download and unpack googletest at configure time +set(GOOGLETEST_PREFIX "${benchmark_BINARY_DIR}/third_party/googletest") +configure_file(${benchmark_SOURCE_DIR}/cmake/GoogleTest.cmake.in ${GOOGLETEST_PREFIX}/CMakeLists.txt @ONLY) + +set(GOOGLETEST_PATH "${CMAKE_CURRENT_SOURCE_DIR}/googletest" CACHE PATH "") # Mind the quotes +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" + -DALLOW_DOWNLOADING_GOOGLETEST=${BENCHMARK_DOWNLOAD_DEPENDENCIES} -DGOOGLETEST_PATH:PATH=${GOOGLETEST_PATH} . + RESULT_VARIABLE result + WORKING_DIRECTORY ${GOOGLETEST_PREFIX} +) + +if(result) + message(FATAL_ERROR "CMake step for googletest failed: ${result}") +endif() + +execute_process( + COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${GOOGLETEST_PREFIX} +) + +if(result) + message(FATAL_ERROR "Build step for googletest failed: ${result}") +endif() + +# Prevent overriding the parent project's compiler/linker +# settings on Windows +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +include(${GOOGLETEST_PREFIX}/googletest-paths.cmake) + +# Add googletest directly to our build. This defines +# the gtest and gtest_main targets. 
+add_subdirectory(${GOOGLETEST_SOURCE_DIR} + ${GOOGLETEST_BINARY_DIR} + EXCLUDE_FROM_ALL) + +# googletest doesn't seem to want to stay build warning clean so let's not hurt ourselves. +if (MSVC) + target_compile_options(gtest PRIVATE "/wd4244" "/wd4722") + target_compile_options(gtest_main PRIVATE "/wd4244" "/wd4722") + target_compile_options(gmock PRIVATE "/wd4244" "/wd4722") + target_compile_options(gmock_main PRIVATE "/wd4244" "/wd4722") +else() + target_compile_options(gtest PRIVATE "-w") + target_compile_options(gtest_main PRIVATE "-w") + target_compile_options(gmock PRIVATE "-w") + target_compile_options(gmock_main PRIVATE "-w") +endif() + +if(NOT DEFINED GTEST_COMPILE_COMMANDS) + set(GTEST_COMPILE_COMMANDS ON) +endif() + +set_target_properties(gtest PROPERTIES INTERFACE_SYSTEM_INCLUDE_DIRECTORIES $ EXPORT_COMPILE_COMMANDS ${GTEST_COMPILE_COMMANDS}) +set_target_properties(gtest_main PROPERTIES INTERFACE_SYSTEM_INCLUDE_DIRECTORIES $ EXPORT_COMPILE_COMMANDS ${GTEST_COMPILE_COMMANDS}) +set_target_properties(gmock PROPERTIES INTERFACE_SYSTEM_INCLUDE_DIRECTORIES $ EXPORT_COMPILE_COMMANDS ${GTEST_COMPILE_COMMANDS}) +set_target_properties(gmock_main PROPERTIES INTERFACE_SYSTEM_INCLUDE_DIRECTORIES $ EXPORT_COMPILE_COMMANDS ${GTEST_COMPILE_COMMANDS}) diff --git a/third_party/benchmark/cmake/GoogleTest.cmake.in b/third_party/benchmark/cmake/GoogleTest.cmake.in new file mode 100644 index 0000000..c791446 --- /dev/null +++ b/third_party/benchmark/cmake/GoogleTest.cmake.in @@ -0,0 +1,59 @@ +cmake_minimum_required (VERSION 3.13...3.22) + +project(googletest-download NONE) + +# Enable ExternalProject CMake module +include(ExternalProject) + +option(ALLOW_DOWNLOADING_GOOGLETEST "If googletest src tree is not found in location specified by GOOGLETEST_PATH, do fetch the archive from internet" OFF) +set(GOOGLETEST_PATH "/usr/src/googletest" CACHE PATH + "Path to the googletest root tree. Should contain googletest and googlemock subdirs. 
And CMakeLists.txt in root, and in both of these subdirs") + +# Download and install GoogleTest + +message(STATUS "Looking for Google Test sources") +message(STATUS "Looking for Google Test sources in ${GOOGLETEST_PATH}") +if(EXISTS "${GOOGLETEST_PATH}" AND IS_DIRECTORY "${GOOGLETEST_PATH}" AND EXISTS "${GOOGLETEST_PATH}/CMakeLists.txt" AND + EXISTS "${GOOGLETEST_PATH}/googletest" AND IS_DIRECTORY "${GOOGLETEST_PATH}/googletest" AND EXISTS "${GOOGLETEST_PATH}/googletest/CMakeLists.txt" AND + EXISTS "${GOOGLETEST_PATH}/googlemock" AND IS_DIRECTORY "${GOOGLETEST_PATH}/googlemock" AND EXISTS "${GOOGLETEST_PATH}/googlemock/CMakeLists.txt") + message(STATUS "Found Google Test in ${GOOGLETEST_PATH}") + + ExternalProject_Add( + googletest + PREFIX "${CMAKE_BINARY_DIR}" + DOWNLOAD_DIR "${CMAKE_BINARY_DIR}/download" + SOURCE_DIR "${GOOGLETEST_PATH}" # use existing src dir. + BINARY_DIR "${CMAKE_BINARY_DIR}/build" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) +else() + if(NOT ALLOW_DOWNLOADING_GOOGLETEST) + message(SEND_ERROR "Did not find Google Test sources! Either pass correct path in GOOGLETEST_PATH, or enable BENCHMARK_DOWNLOAD_DEPENDENCIES, or disable BENCHMARK_USE_BUNDLED_GTEST, or disable BENCHMARK_ENABLE_GTEST_TESTS / BENCHMARK_ENABLE_TESTING.") + return() + else() + message(WARNING "Did not find Google Test sources! 
Fetching from web...") + ExternalProject_Add( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG "v1.14.0" + PREFIX "${CMAKE_BINARY_DIR}" + STAMP_DIR "${CMAKE_BINARY_DIR}/stamp" + DOWNLOAD_DIR "${CMAKE_BINARY_DIR}/download" + SOURCE_DIR "${CMAKE_BINARY_DIR}/src" + BINARY_DIR "${CMAKE_BINARY_DIR}/build" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" + ) + endif() +endif() + +ExternalProject_Get_Property(googletest SOURCE_DIR BINARY_DIR) +file(WRITE googletest-paths.cmake +"set(GOOGLETEST_SOURCE_DIR \"${SOURCE_DIR}\") +set(GOOGLETEST_BINARY_DIR \"${BINARY_DIR}\") +") diff --git a/third_party/benchmark/cmake/benchmark.pc.in b/third_party/benchmark/cmake/benchmark.pc.in new file mode 100644 index 0000000..bbed29d --- /dev/null +++ b/third_party/benchmark/cmake/benchmark.pc.in @@ -0,0 +1,12 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: @PROJECT_NAME@ +Description: Google microbenchmark framework +Version: @NORMALIZED_VERSION@ + +Libs: -L${libdir} -lbenchmark +Libs.private: -lpthread @BENCHMARK_PRIVATE_LINK_LIBRARIES@ +Cflags: -I${includedir} diff --git a/third_party/benchmark/cmake/benchmark_main.pc.in b/third_party/benchmark/cmake/benchmark_main.pc.in new file mode 100644 index 0000000..e9d81a0 --- /dev/null +++ b/third_party/benchmark/cmake/benchmark_main.pc.in @@ -0,0 +1,7 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ + +Name: @PROJECT_NAME@ +Description: Google microbenchmark framework (with main() function) +Version: @NORMALIZED_VERSION@ +Requires: benchmark +Libs: -L${libdir} -lbenchmark_main diff --git a/third_party/benchmark/cmake/gnu_posix_regex.cpp b/third_party/benchmark/cmake/gnu_posix_regex.cpp new file mode 100644 index 0000000..b5b91cd --- /dev/null +++ b/third_party/benchmark/cmake/gnu_posix_regex.cpp @@ -0,0 +1,12 @@ +#include +#include +int main() { + std::string str = "test0159"; + regex_t re; + 
int ec = regcomp(&re, "^[a-z]+[0-9]+$", REG_EXTENDED | REG_NOSUB);
+  if (ec != 0) {
+    return ec;
+  }
+  return regexec(&re, str.c_str(), 0, nullptr, 0) ? -1 : 0;
+}
+
diff --git a/third_party/benchmark/cmake/llvm-toolchain.cmake b/third_party/benchmark/cmake/llvm-toolchain.cmake
new file mode 100644
index 0000000..fc119e5
--- /dev/null
+++ b/third_party/benchmark/cmake/llvm-toolchain.cmake
@@ -0,0 +1,8 @@
+find_package(LLVMAr REQUIRED)
+set(CMAKE_AR "${LLVMAR_EXECUTABLE}" CACHE FILEPATH "" FORCE)
+
+find_package(LLVMNm REQUIRED)
+set(CMAKE_NM "${LLVMNM_EXECUTABLE}" CACHE FILEPATH "" FORCE)
+
+find_package(LLVMRanLib REQUIRED)
+set(CMAKE_RANLIB "${LLVMRANLIB_EXECUTABLE}" CACHE FILEPATH "" FORCE)
diff --git a/third_party/benchmark/cmake/posix_regex.cpp b/third_party/benchmark/cmake/posix_regex.cpp
new file mode 100644
index 0000000..466dc62
--- /dev/null
+++ b/third_party/benchmark/cmake/posix_regex.cpp
@@ -0,0 +1,14 @@
+#include <regex.h>
+#include <string>
+int main() {
+  std::string str = "test0159";
+  regex_t re;
+  int ec = regcomp(&re, "^[a-z]+[0-9]+$", REG_EXTENDED | REG_NOSUB);
+  if (ec != 0) {
+    return ec;
+  }
+  int ret = regexec(&re, str.c_str(), 0, nullptr, 0) ? -1 : 0;
+  regfree(&re);
+  return ret;
+}
+
diff --git a/third_party/benchmark/cmake/pthread_affinity.cpp b/third_party/benchmark/cmake/pthread_affinity.cpp
new file mode 100644
index 0000000..7b143bc
--- /dev/null
+++ b/third_party/benchmark/cmake/pthread_affinity.cpp
@@ -0,0 +1,16 @@
+#include <pthread.h>
+int main() {
+  cpu_set_t set;
+  CPU_ZERO(&set);
+  for (int i = 0; i < CPU_SETSIZE; ++i) {
+    CPU_SET(i, &set);
+    CPU_CLR(i, &set);
+  }
+  pthread_t self = pthread_self();
+  int ret;
+  ret = pthread_getaffinity_np(self, sizeof(set), &set);
+  if (ret != 0) return ret;
+  ret = pthread_setaffinity_np(self, sizeof(set), &set);
+  if (ret != 0) return ret;
+  return 0;
+}
diff --git a/third_party/benchmark/cmake/split_list.cmake b/third_party/benchmark/cmake/split_list.cmake
new file mode 100644
index 0000000..67aed3f
--- /dev/null
+++ b/third_party/benchmark/cmake/split_list.cmake
@@ -0,0 +1,3 @@
+macro(split_list listname)
+  string(REPLACE ";" " " ${listname} "${${listname}}")
+endmacro()
diff --git a/third_party/benchmark/cmake/std_regex.cpp b/third_party/benchmark/cmake/std_regex.cpp
new file mode 100644
index 0000000..696f2a2
--- /dev/null
+++ b/third_party/benchmark/cmake/std_regex.cpp
@@ -0,0 +1,10 @@
+#include <regex>
+#include <string>
+int main() {
+  const std::string str = "test0159";
+  std::regex re;
+  re = std::regex("^[a-z]+[0-9]+$",
+                  std::regex_constants::extended | std::regex_constants::nosubs);
+  return std::regex_search(str, re) ? 
0 : -1;
+}
+
diff --git a/third_party/benchmark/cmake/steady_clock.cpp b/third_party/benchmark/cmake/steady_clock.cpp
new file mode 100644
index 0000000..66d50d1
--- /dev/null
+++ b/third_party/benchmark/cmake/steady_clock.cpp
@@ -0,0 +1,7 @@
+#include <chrono>
+
+int main() {
+  typedef std::chrono::steady_clock Clock;
+  Clock::time_point tp = Clock::now();
+  ((void)tp);
+}
diff --git a/third_party/benchmark/cmake/thread_safety_attributes.cpp b/third_party/benchmark/cmake/thread_safety_attributes.cpp
new file mode 100644
index 0000000..46161ba
--- /dev/null
+++ b/third_party/benchmark/cmake/thread_safety_attributes.cpp
@@ -0,0 +1,4 @@
+#define HAVE_THREAD_SAFETY_ATTRIBUTES
+#include "../src/mutex.h"
+
+int main() {}
diff --git a/third_party/benchmark/docs/AssemblyTests.md b/third_party/benchmark/docs/AssemblyTests.md
new file mode 100644
index 0000000..89df7ca
--- /dev/null
+++ b/third_party/benchmark/docs/AssemblyTests.md
@@ -0,0 +1,149 @@
+# Assembly Tests
+
+The Benchmark library provides a number of functions whose primary
+purpose is to affect assembly generation, including `DoNotOptimize`
+and `ClobberMemory`. In addition there are other functions,
+such as `KeepRunning`, for which generating good assembly is paramount.
+
+For these functions it's important to have tests that verify the
+correctness and quality of the implementation. This requires testing
+the code generated by the compiler.
+
+This document describes how the Benchmark library tests compiler output,
+as well as how to properly write new tests.
+
+
+## Anatomy of a Test
+
+Writing a test has two steps:
+
+* Write the code you want to generate assembly for.
+* Add `// CHECK` lines to match against the verified assembly.
+
+Example:
+```c++
+
+// CHECK-LABEL: test_add:
+extern "C" int test_add() {
+    extern int ExternInt;
+    return ExternInt + 1;
+
+    // CHECK: movl ExternInt(%rip), %eax
+    // CHECK: addl %eax
+    // CHECK: ret
+}
+
+```
+
+#### LLVM Filecheck
+
+[LLVM's Filecheck](https://llvm.org/docs/CommandGuide/FileCheck.html)
+is used to test the generated assembly against the `// CHECK` lines
+specified in the tests source file. Please see the documentation
+linked above for information on how to write `CHECK` directives.
+
+#### Tips and Tricks:
+
+* Tests should match the minimal amount of output required to establish
+correctness. `CHECK` directives don't have to match on the exact next line
+after the previous match, so tests should omit checks for unimportant
+bits of assembly. ([`CHECK-NEXT`](https://llvm.org/docs/CommandGuide/FileCheck.html#the-check-next-directive)
+can be used to ensure a match occurs exactly after the previous match).
+
+* The tests are compiled with `-O3 -g0`. So we're only testing the
+optimized output.
+
+* The assembly output is further cleaned up using `tools/strip_asm.py`.
+This removes comments, assembler directives, and unused labels before
+the test is run.
+
+* The generated and stripped assembly file for a test is output under
+`<build/dir>/test/<test-name>.s`
+
+* Filecheck supports using [`CHECK` prefixes](https://llvm.org/docs/CommandGuide/FileCheck.html#cmdoption-check-prefixes)
+to specify lines that should only match in certain situations.
+The Benchmark tests use `CHECK-CLANG` and `CHECK-GNU` for lines that
+are only expected to match Clang or GCC's output respectively. Normal
+`CHECK` lines match against all compilers. (Note: `CHECK-NOT` and
+`CHECK-LABEL` are NOT prefixes. They are versions of non-prefixed
+`CHECK` lines)
+
+* Use `extern "C"` to disable name mangling for specific functions. This
+makes them easier to name in the `CHECK` lines.
+ + +## Problems Writing Portable Tests + +Writing tests which check the code generated by a compiler are +inherently non-portable. Different compilers and even different compiler +versions may generate entirely different code. The Benchmark tests +must tolerate this. + +LLVM Filecheck provides a number of mechanisms to help write +"more portable" tests; including [matching using regular expressions](https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-pattern-matching-syntax), +allowing the creation of [named variables](https://llvm.org/docs/CommandGuide/FileCheck.html#filecheck-variables) +for later matching, and [checking non-sequential matches](https://llvm.org/docs/CommandGuide/FileCheck.html#the-check-dag-directive). + +#### Capturing Variables + +For example, say GCC stores a variable in a register but Clang stores +it in memory. To write a test that tolerates both cases we "capture" +the destination of the store, and then use the captured expression +to write the remainder of the test. + +```c++ +// CHECK-LABEL: test_div_no_op_into_shr: +extern "C" void test_div_no_op_into_shr(int value) { + int divisor = 2; + benchmark::DoNotOptimize(divisor); // hide the value from the optimizer + return value / divisor; + + // CHECK: movl $2, [[DEST:.*]] + // CHECK: idivl [[DEST]] + // CHECK: ret +} +``` + +#### Using Regular Expressions to Match Differing Output + +Often tests require testing assembly lines which may subtly differ +between compilers or compiler versions. A common example of this +is matching stack frame addresses. In this case regular expressions +can be used to match the differing bits of output. 
For example: + + +```c++ +int ExternInt; +struct Point { int x, y, z; }; + +// CHECK-LABEL: test_store_point: +extern "C" void test_store_point() { + Point p{ExternInt, ExternInt, ExternInt}; + benchmark::DoNotOptimize(p); + + // CHECK: movl ExternInt(%rip), %eax + // CHECK: movl %eax, -{{[0-9]+}}(%rsp) + // CHECK: movl %eax, -{{[0-9]+}}(%rsp) + // CHECK: movl %eax, -{{[0-9]+}}(%rsp) + // CHECK: ret +} +``` + + +## Current Requirements and Limitations + +The tests require Filecheck to be installed along the `PATH` of the +build machine. Otherwise the tests will be disabled. + +Additionally, as mentioned in the previous section, codegen tests are +inherently non-portable. Currently the tests are limited to: + +* x86_64 targets. +* Compiled with GCC or Clang + +Further work could be done, at least on a limited basis, to extend the +tests to other architectures and compilers (using `CHECK` prefixes). + +Furthermore, the tests fail for builds which specify additional flags +that modify code generation, including `--coverage` or `-fsanitize=`. 
+ diff --git a/third_party/benchmark/docs/_config.yml b/third_party/benchmark/docs/_config.yml new file mode 100644 index 0000000..32f9f2e --- /dev/null +++ b/third_party/benchmark/docs/_config.yml @@ -0,0 +1,3 @@ +theme: jekyll-theme-minimal +logo: /assets/images/icon_black.png +show_downloads: true diff --git a/third_party/benchmark/docs/assets/images/icon.png b/third_party/benchmark/docs/assets/images/icon.png new file mode 100644 index 0000000..b982604 Binary files /dev/null and b/third_party/benchmark/docs/assets/images/icon.png differ diff --git a/third_party/benchmark/docs/assets/images/icon.xcf b/third_party/benchmark/docs/assets/images/icon.xcf new file mode 100644 index 0000000..f2f0be4 Binary files /dev/null and b/third_party/benchmark/docs/assets/images/icon.xcf differ diff --git a/third_party/benchmark/docs/assets/images/icon_black.png b/third_party/benchmark/docs/assets/images/icon_black.png new file mode 100644 index 0000000..656ae79 Binary files /dev/null and b/third_party/benchmark/docs/assets/images/icon_black.png differ diff --git a/third_party/benchmark/docs/assets/images/icon_black.xcf b/third_party/benchmark/docs/assets/images/icon_black.xcf new file mode 100644 index 0000000..430e7ba Binary files /dev/null and b/third_party/benchmark/docs/assets/images/icon_black.xcf differ diff --git a/third_party/benchmark/docs/dependencies.md b/third_party/benchmark/docs/dependencies.md new file mode 100644 index 0000000..98ce996 --- /dev/null +++ b/third_party/benchmark/docs/dependencies.md @@ -0,0 +1,19 @@ +# Build tool dependency policy + +We follow the [Foundational C++ support policy](https://opensource.google/documentation/policies/cplusplus-support) for our build tools. In +particular the ["Build Systems" section](https://opensource.google/documentation/policies/cplusplus-support#build-systems). + +## CMake + +The current supported version is CMake 3.10 as of 2023-08-10. 
Most modern +distributions include newer versions, for example: + +* Ubuntu 20.04 provides CMake 3.16.3 +* Debian 11.4 provides CMake 3.18.4 +* Ubuntu 22.04 provides CMake 3.22.1 + +## Python + +The Python bindings require Python 3.10+ as of v1.9.0 (2024-08-16) for installation from PyPI. +Building from source for older versions probably still works, though. See the [user guide](python_bindings.md) for details on how to build from source. +The minimum theoretically supported version is Python 3.8, since the used bindings generator (nanobind) only supports Python 3.8+. diff --git a/third_party/benchmark/docs/index.md b/third_party/benchmark/docs/index.md new file mode 100644 index 0000000..9cada96 --- /dev/null +++ b/third_party/benchmark/docs/index.md @@ -0,0 +1,12 @@ +# Benchmark + +* [Assembly Tests](AssemblyTests.md) +* [Dependencies](dependencies.md) +* [Perf Counters](perf_counters.md) +* [Platform Specific Build Instructions](platform_specific_build_instructions.md) +* [Python Bindings](python_bindings.md) +* [Random Interleaving](random_interleaving.md) +* [Reducing Variance](reducing_variance.md) +* [Releasing](releasing.md) +* [Tools](tools.md) +* [User Guide](user_guide.md) diff --git a/third_party/benchmark/docs/perf_counters.md b/third_party/benchmark/docs/perf_counters.md new file mode 100644 index 0000000..f342092 --- /dev/null +++ b/third_party/benchmark/docs/perf_counters.md @@ -0,0 +1,35 @@ + + +# User-Requested Performance Counters + +When running benchmarks, the user may choose to request collection of +performance counters. This may be useful in investigation scenarios - narrowing +down the cause of a regression; or verifying that the underlying cause of a +performance improvement matches expectations. + +This feature is available if: + +* The benchmark is run on an architecture featuring a Performance Monitoring + Unit (PMU), +* The benchmark is compiled with support for collecting counters. 
Currently, + this requires [libpfm](http://perfmon2.sourceforge.net/), which is built as a + dependency via Bazel. + +The feature does not require modifying benchmark code. Counter collection is +handled at the boundaries where timer collection is also handled. + +To opt-in: +* If using a Bazel build, add `--define pfm=1` to your build flags +* If using CMake: + * Install `libpfm4-dev`, e.g. `apt-get install libpfm4-dev`. + * Enable the CMake flag `BENCHMARK_ENABLE_LIBPFM` in `CMakeLists.txt`. + +To use, pass a comma-separated list of counter names through the +`--benchmark_perf_counters` flag. The names are decoded through libpfm - meaning, +they are platform specific, but some (e.g. `CYCLES` or `INSTRUCTIONS`) are +mapped by libpfm to platform-specifics - see libpfm +[documentation](http://perfmon2.sourceforge.net/docs.html) for more details. + +The counter values are reported back through the [User Counters](../README.md#custom-counters) +mechanism, meaning, they are available in all the formats (e.g. JSON) supported +by User Counters. diff --git a/third_party/benchmark/docs/platform_specific_build_instructions.md b/third_party/benchmark/docs/platform_specific_build_instructions.md new file mode 100644 index 0000000..2d5d6c4 --- /dev/null +++ b/third_party/benchmark/docs/platform_specific_build_instructions.md @@ -0,0 +1,48 @@ +# Platform Specific Build Instructions + +## Building with GCC + +When the library is built using GCC it is necessary to link with the pthread +library due to how GCC implements `std::thread`. Failing to link to pthread will +lead to runtime exceptions (unless you're using libc++), not linker errors. See +[issue #67](https://github.com/google/benchmark/issues/67) for more details. You +can link to pthread by adding `-pthread` to your linker command. Note, you can +also use `-lpthread`, but there are potential issues with ordering of command +line parameters if you use that. 
+ +On QNX, the pthread library is part of libc and usually included automatically +(see +[`pthread_create()`](https://www.qnx.com/developers/docs/7.1/index.html#com.qnx.doc.neutrino.lib_ref/topic/p/pthread_create.html)). +There's no separate pthread library to link. + +## Building with Visual Studio 2015 or 2017 + +The `shlwapi` library (`-lshlwapi`) is required to support a call to `CPUInfo` which reads the registry. Either add `shlwapi.lib` under `[ Configuration Properties > Linker > Input ]`, or use the following: + +``` +// Alternatively, can add libraries using linker options. +#ifdef _WIN32 +#pragma comment ( lib, "Shlwapi.lib" ) +#ifdef _DEBUG +#pragma comment ( lib, "benchmarkd.lib" ) +#else +#pragma comment ( lib, "benchmark.lib" ) +#endif +#endif +``` + +Can also use the graphical version of CMake: +* Open `CMake GUI`. +* Under `Where to build the binaries`, same path as source plus `build`. +* Under `CMAKE_INSTALL_PREFIX`, same path as source plus `install`. +* Click `Configure`, `Generate`, `Open Project`. +* If build fails, try deleting entire directory and starting again, or unticking options to build less. + +## Building with Intel 2015 Update 1 or Intel System Studio Update 4 + +See instructions for building with Visual Studio. Once built, right click on the solution and change the build to Intel. + +## Building on Solaris + +If you're running benchmarks on solaris, you'll want the kstat library linked in +too (`-lkstat`). \ No newline at end of file diff --git a/third_party/benchmark/docs/python_bindings.md b/third_party/benchmark/docs/python_bindings.md new file mode 100644 index 0000000..d9c5d2d --- /dev/null +++ b/third_party/benchmark/docs/python_bindings.md @@ -0,0 +1,34 @@ +# Building and installing Python bindings + +Python bindings are available as wheels on [PyPI](https://pypi.org/project/google-benchmark/) for importing and +using Google Benchmark directly in Python. 
+Currently, pre-built wheels exist for macOS (both ARM64 and Intel x86), Linux x86-64 and 64-bit Windows. +Supported Python versions are Python 3.8 - 3.12. + +To install Google Benchmark's Python bindings, run: + +```bash +python -m pip install --upgrade pip # for manylinux2014 support +python -m pip install google-benchmark +``` + +In order to keep your system Python interpreter clean, it is advisable to run these commands in a virtual +environment. See the [official Python documentation](https://docs.python.org/3/library/venv.html) +on how to create virtual environments. + +To build a wheel directly from source, you can follow these steps: +```bash +git clone https://github.com/google/benchmark.git +cd benchmark +# create a virtual environment and activate it +python3 -m venv venv --system-site-packages +source venv/bin/activate # .\venv\Scripts\Activate.ps1 on Windows + +# upgrade Python's system-wide packages +python -m pip install --upgrade pip build +# builds the wheel and stores it in the directory "dist". +python -m build +``` + +NB: Building wheels from source requires Bazel. For platform-specific instructions on how to install Bazel, +refer to the [Bazel installation docs](https://bazel.build/install). diff --git a/third_party/benchmark/docs/random_interleaving.md b/third_party/benchmark/docs/random_interleaving.md new file mode 100644 index 0000000..c083036 --- /dev/null +++ b/third_party/benchmark/docs/random_interleaving.md @@ -0,0 +1,13 @@ + + +# Random Interleaving + +[Random Interleaving](https://github.com/google/benchmark/issues/1051) is a +technique to lower run-to-run variance. It randomly interleaves repetitions of a +microbenchmark with repetitions from other microbenchmarks in the same benchmark +test. Data shows it is able to lower run-to-run variance by +[40%](https://github.com/google/benchmark/issues/1051) on average. 
+ +To use, you mainly need to set `--benchmark_enable_random_interleaving=true`, +and optionally specify non-zero repetition count `--benchmark_repetitions=9` +and optionally decrease the per-repetition time `--benchmark_min_time=0.1`. diff --git a/third_party/benchmark/docs/reducing_variance.md b/third_party/benchmark/docs/reducing_variance.md new file mode 100644 index 0000000..105f96e --- /dev/null +++ b/third_party/benchmark/docs/reducing_variance.md @@ -0,0 +1,98 @@ +# Reducing Variance + + + +## Disabling CPU Frequency Scaling + +If you see this error: + +``` +***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead. +``` + +you might want to disable the CPU frequency scaling while running the +benchmark, as well as consider other ways to stabilize the performance of +your system while benchmarking. + +Exactly how to do this depends on the Linux distribution, +desktop environment, and installed programs. Specific details are a moving +target, so we will not attempt to exhaustively document them here. + +One simple option is to use the `cpupower` program to change the +performance governor to "performance". This tool is maintained along with +the Linux kernel and provided by your distribution. + +It must be run as root, like this: + +```bash +sudo cpupower frequency-set --governor performance +``` + +After this you can verify that all CPUs are using the performance governor +by running this command: + +```bash +cpupower frequency-info -o proc +``` + +The benchmarks you subsequently run will have less variance. + + + +## Reducing Variance in Benchmarks + +The Linux CPU frequency governor [discussed +above](user_guide#disabling-cpu-frequency-scaling) is not the only source +of noise in benchmarks. Some, but not all, of the sources of variance +include: + +1. 
On multi-core machines not all CPUs/CPU cores/CPU threads run the same + speed, so running a benchmark one time and then again may give a + different result depending on which CPU it ran on. +2. CPU scaling features that run on the CPU, like Intel's Turbo Boost and + AMD Turbo Core and Precision Boost, can temporarily change the CPU + frequency even when using the "performance" governor on Linux. +3. Context switching between CPUs, or scheduling competition on the CPU the + benchmark is running on. +4. Intel Hyperthreading or AMD SMT causing the same issue as above. +5. Cache effects caused by code running on other CPUs. +6. Non-uniform memory architectures (NUMA). + +These can cause variance in benchmark results within a single run +(`--benchmark_repetitions=N`) or across multiple runs of the benchmark +program. + +Reducing sources of variance is OS and architecture dependent, which is one +reason some companies maintain machines dedicated to performance testing. + +Some of the easier and effective ways of reducing variance on a typical +Linux workstation are: + +1. Use the performance governor as [discussed +above](user_guide#disabling-cpu-frequency-scaling). +1. Disable processor boosting by: + ```sh + echo 0 | sudo tee /sys/devices/system/cpu/cpufreq/boost + ``` + See the Linux kernel's + [boost.txt](https://www.kernel.org/doc/Documentation/cpu-freq/boost.txt) + for more information. +2. Set the benchmark program's task affinity to a fixed cpu. For example: + ```sh + taskset -c 0 ./mybenchmark + ``` +3. Disabling Hyperthreading/SMT. This can be done in the Bios or using the + `/sys` file system (see the LLVM project's [Benchmarking + tips](https://llvm.org/docs/Benchmarking.html)). +4. Close other programs that do non-trivial things based on timers, such as + your web browser, desktop environment, etc. +5. 
Reduce the working set of your benchmark to fit within the L1 cache, but + do be aware that this may lead you to optimize for an unrealistic + situation. + +Further resources on this topic: + +1. The LLVM project's [Benchmarking + tips](https://llvm.org/docs/Benchmarking.html). +1. The Arch Wiki [Cpu frequency +scaling](https://wiki.archlinux.org/title/CPU_frequency_scaling) page. diff --git a/third_party/benchmark/docs/releasing.md b/third_party/benchmark/docs/releasing.md new file mode 100644 index 0000000..ab664a8 --- /dev/null +++ b/third_party/benchmark/docs/releasing.md @@ -0,0 +1,38 @@ +# How to release + +* Make sure you're on main and synced to HEAD +* Ensure the project builds and tests run + * `parallel -j0 exec ::: test/*_test` can help ensure everything at least + passes +* Prepare release notes + * `git log $(git describe --abbrev=0 --tags)..HEAD` gives you the list of + commits between the last annotated tag and HEAD + * Pick the most interesting. +* Create one last commit that updates the version saved in `CMakeLists.txt`, `MODULE.bazel`, + and `bindings/python/google_benchmark/__init__.py` to the release version you're creating. + (This version will be used if benchmark is installed from the archive you'll be creating + in the next step.) + +``` +# CMakeLists.txt +project (benchmark VERSION 1.9.0 LANGUAGES CXX) +``` + +``` +# MODULE.bazel +module(name = "com_github_google_benchmark", version="1.9.0") +``` + +``` +# google_benchmark/__init__.py +__version__ = "1.9.0" +``` + +* Create a release through github's interface + * Note this will create a lightweight tag. + * Update this to an annotated tag: + * `git pull --tags` + * `git tag -a -f ` + * `git push --force --tags origin` +* Confirm that the "Build and upload Python wheels" action runs to completion + * Run it manually if it hasn't run. 
diff --git a/third_party/benchmark/docs/tools.md b/third_party/benchmark/docs/tools.md new file mode 100644 index 0000000..411f41d --- /dev/null +++ b/third_party/benchmark/docs/tools.md @@ -0,0 +1,343 @@ +# Benchmark Tools + +## compare.py + +The `compare.py` can be used to compare the result of benchmarks. + +### Dependencies +The utility relies on the [scipy](https://www.scipy.org) package which can be installed using pip: +```bash +pip3 install -r requirements.txt +``` + +### Displaying aggregates only + +The switch `-a` / `--display_aggregates_only` can be used to control the +displayment of the normal iterations vs the aggregates. When passed, it will +be passthrough to the benchmark binaries to be run, and will be accounted for +in the tool itself; only the aggregates will be displayed, but not normal runs. +It only affects the display, the separate runs will still be used to calculate +the U test. + +### Modes of operation + +There are three modes of operation: + +1. Just compare two benchmarks +The program is invoked like: + +``` bash +$ compare.py benchmarks [benchmark options]... +``` +Where `` and `` either specify a benchmark executable file, or a JSON output file. The type of the input file is automatically detected. If a benchmark executable is specified then the benchmark is run to obtain the results. Otherwise the results are simply loaded from the output file. + +`[benchmark options]` will be passed to the benchmarks invocations. They can be anything that binary accepts, be it either normal `--benchmark_*` parameters, or some custom parameters your binary takes. 
+ +Example output: +``` +$ ./compare.py benchmarks ./a.out ./a.out +RUNNING: ./a.out --benchmark_out=/tmp/tmprBT5nW +Run on (8 X 4000 MHz CPU s) +2017-11-07 21:16:44 +------------------------------------------------------ +Benchmark Time CPU Iterations +------------------------------------------------------ +BM_memcpy/8 36 ns 36 ns 19101577 211.669MB/s +BM_memcpy/64 76 ns 76 ns 9412571 800.199MB/s +BM_memcpy/512 84 ns 84 ns 8249070 5.64771GB/s +BM_memcpy/1024 116 ns 116 ns 6181763 8.19505GB/s +BM_memcpy/8192 643 ns 643 ns 1062855 11.8636GB/s +BM_copy/8 222 ns 222 ns 3137987 34.3772MB/s +BM_copy/64 1608 ns 1608 ns 432758 37.9501MB/s +BM_copy/512 12589 ns 12589 ns 54806 38.7867MB/s +BM_copy/1024 25169 ns 25169 ns 27713 38.8003MB/s +BM_copy/8192 201165 ns 201112 ns 3486 38.8466MB/s +RUNNING: ./a.out --benchmark_out=/tmp/tmpt1wwG_ +Run on (8 X 4000 MHz CPU s) +2017-11-07 21:16:53 +------------------------------------------------------ +Benchmark Time CPU Iterations +------------------------------------------------------ +BM_memcpy/8 36 ns 36 ns 19397903 211.255MB/s +BM_memcpy/64 73 ns 73 ns 9691174 839.635MB/s +BM_memcpy/512 85 ns 85 ns 8312329 5.60101GB/s +BM_memcpy/1024 118 ns 118 ns 6438774 8.11608GB/s +BM_memcpy/8192 656 ns 656 ns 1068644 11.6277GB/s +BM_copy/8 223 ns 223 ns 3146977 34.2338MB/s +BM_copy/64 1611 ns 1611 ns 435340 37.8751MB/s +BM_copy/512 12622 ns 12622 ns 54818 38.6844MB/s +BM_copy/1024 25257 ns 25239 ns 27779 38.6927MB/s +BM_copy/8192 205013 ns 205010 ns 3479 38.108MB/s +Comparing ./a.out to ./a.out +Benchmark Time CPU Time Old Time New CPU Old CPU New +------------------------------------------------------------------------------------------------------ +BM_memcpy/8 +0.0020 +0.0020 36 36 36 36 +BM_memcpy/64 -0.0468 -0.0470 76 73 76 73 +BM_memcpy/512 +0.0081 +0.0083 84 85 84 85 +BM_memcpy/1024 +0.0098 +0.0097 116 118 116 118 +BM_memcpy/8192 +0.0200 +0.0203 643 656 643 656 +BM_copy/8 +0.0046 +0.0042 222 223 222 223 +BM_copy/64 +0.0020 +0.0020 1608 
1611 1608 1611 +BM_copy/512 +0.0027 +0.0026 12589 12622 12589 12622 +BM_copy/1024 +0.0035 +0.0028 25169 25257 25169 25239 +BM_copy/8192 +0.0191 +0.0194 201165 205013 201112 205010 +``` + +What it does is for the every benchmark from the first run it looks for the benchmark with exactly the same name in the second run, and then compares the results. If the names differ, the benchmark is omitted from the diff. +As you can note, the values in `Time` and `CPU` columns are calculated as `(new - old) / |old|`. + +2. Compare two different filters of one benchmark +The program is invoked like: + +``` bash +$ compare.py filters [benchmark options]... +``` +Where `` either specify a benchmark executable file, or a JSON output file. The type of the input file is automatically detected. If a benchmark executable is specified then the benchmark is run to obtain the results. Otherwise the results are simply loaded from the output file. + +Where `` and `` are the same regex filters that you would pass to the `[--benchmark_filter=]` parameter of the benchmark binary. + +`[benchmark options]` will be passed to the benchmarks invocations. They can be anything that binary accepts, be it either normal `--benchmark_*` parameters, or some custom parameters your binary takes. 
+ +Example output: +``` +$ ./compare.py filters ./a.out BM_memcpy BM_copy +RUNNING: ./a.out --benchmark_filter=BM_memcpy --benchmark_out=/tmp/tmpBWKk0k +Run on (8 X 4000 MHz CPU s) +2017-11-07 21:37:28 +------------------------------------------------------ +Benchmark Time CPU Iterations +------------------------------------------------------ +BM_memcpy/8 36 ns 36 ns 17891491 211.215MB/s +BM_memcpy/64 74 ns 74 ns 9400999 825.646MB/s +BM_memcpy/512 87 ns 87 ns 8027453 5.46126GB/s +BM_memcpy/1024 111 ns 111 ns 6116853 8.5648GB/s +BM_memcpy/8192 657 ns 656 ns 1064679 11.6247GB/s +RUNNING: ./a.out --benchmark_filter=BM_copy --benchmark_out=/tmp/tmpAvWcOM +Run on (8 X 4000 MHz CPU s) +2017-11-07 21:37:33 +---------------------------------------------------- +Benchmark Time CPU Iterations +---------------------------------------------------- +BM_copy/8 227 ns 227 ns 3038700 33.6264MB/s +BM_copy/64 1640 ns 1640 ns 426893 37.2154MB/s +BM_copy/512 12804 ns 12801 ns 55417 38.1444MB/s +BM_copy/1024 25409 ns 25407 ns 27516 38.4365MB/s +BM_copy/8192 202986 ns 202990 ns 3454 38.4871MB/s +Comparing BM_memcpy to BM_copy (from ./a.out) +Benchmark Time CPU Time Old Time New CPU Old CPU New +-------------------------------------------------------------------------------------------------------------------- +[BM_memcpy vs. BM_copy]/8 +5.2829 +5.2812 36 227 36 227 +[BM_memcpy vs. BM_copy]/64 +21.1719 +21.1856 74 1640 74 1640 +[BM_memcpy vs. BM_copy]/512 +145.6487 +145.6097 87 12804 87 12801 +[BM_memcpy vs. BM_copy]/1024 +227.1860 +227.1776 111 25409 111 25407 +[BM_memcpy vs. BM_copy]/8192 +308.1664 +308.2898 657 202986 656 202990 +``` + +As you can see, it applies filter to the benchmarks, both when running the benchmark, and before doing the diff. And to make the diff work, the matches are replaced with some common string. Thus, you can compare two different benchmark families within one benchmark binary. 
+As you can note, the values in `Time` and `CPU` columns are calculated as `(new - old) / |old|`. + +3. Compare filter one from benchmark one to filter two from benchmark two: +The program is invoked like: + +``` bash +$ compare.py filters [benchmark options]... +``` + +Where `` and `` either specify a benchmark executable file, or a JSON output file. The type of the input file is automatically detected. If a benchmark executable is specified then the benchmark is run to obtain the results. Otherwise the results are simply loaded from the output file. + +Where `` and `` are the same regex filters that you would pass to the `[--benchmark_filter=]` parameter of the benchmark binary. + +`[benchmark options]` will be passed to the benchmarks invocations. They can be anything that binary accepts, be it either normal `--benchmark_*` parameters, or some custom parameters your binary takes. + +Example output: +``` +$ ./compare.py benchmarksfiltered ./a.out BM_memcpy ./a.out BM_copy +RUNNING: ./a.out --benchmark_filter=BM_memcpy --benchmark_out=/tmp/tmp_FvbYg +Run on (8 X 4000 MHz CPU s) +2017-11-07 21:38:27 +------------------------------------------------------ +Benchmark Time CPU Iterations +------------------------------------------------------ +BM_memcpy/8 37 ns 37 ns 18953482 204.118MB/s +BM_memcpy/64 74 ns 74 ns 9206578 828.245MB/s +BM_memcpy/512 91 ns 91 ns 8086195 5.25476GB/s +BM_memcpy/1024 120 ns 120 ns 5804513 7.95662GB/s +BM_memcpy/8192 664 ns 664 ns 1028363 11.4948GB/s +RUNNING: ./a.out --benchmark_filter=BM_copy --benchmark_out=/tmp/tmpDfL5iE +Run on (8 X 4000 MHz CPU s) +2017-11-07 21:38:32 +---------------------------------------------------- +Benchmark Time CPU Iterations +---------------------------------------------------- +BM_copy/8 230 ns 230 ns 2985909 33.1161MB/s +BM_copy/64 1654 ns 1653 ns 419408 36.9137MB/s +BM_copy/512 13122 ns 13120 ns 53403 37.2156MB/s +BM_copy/1024 26679 ns 26666 ns 26575 36.6218MB/s +BM_copy/8192 215068 ns 215053 ns 3221 
36.3283MB/s +Comparing BM_memcpy (from ./a.out) to BM_copy (from ./a.out) +Benchmark Time CPU Time Old Time New CPU Old CPU New +-------------------------------------------------------------------------------------------------------------------- +[BM_memcpy vs. BM_copy]/8 +5.1649 +5.1637 37 230 37 230 +[BM_memcpy vs. BM_copy]/64 +21.4352 +21.4374 74 1654 74 1653 +[BM_memcpy vs. BM_copy]/512 +143.6022 +143.5865 91 13122 91 13120 +[BM_memcpy vs. BM_copy]/1024 +221.5903 +221.4790 120 26679 120 26666 +[BM_memcpy vs. BM_copy]/8192 +322.9059 +323.0096 664 215068 664 215053 +``` +This is a mix of the previous two modes, two (potentially different) benchmark binaries are run, and a different filter is applied to each one. +As you can note, the values in `Time` and `CPU` columns are calculated as `(new - old) / |old|`. + +### Note: Interpreting the output + +Performance measurements are an art, and performance comparisons are doubly so. +Results are often noisy and don't necessarily have large absolute differences to +them, so just by visual inspection, it is not at all apparent if two +measurements are actually showing a performance change or not. It is even more +confusing with multiple benchmark repetitions. + +Thankfully, what we can do, is use statistical tests on the results to determine +whether the performance has statistically-significantly changed. `compare.py` +uses [Mann–Whitney U +test](https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test), with a null +hypothesis being that there's no difference in performance. 
+ +**The below output is a summary of a benchmark comparison with statistics +provided for a multi-threaded process.** +``` +Benchmark Time CPU Time Old Time New CPU Old CPU New +----------------------------------------------------------------------------------------------------------------------------- +benchmark/threads:1/process_time/real_time_pvalue 0.0000 0.0000 U Test, Repetitions: 27 vs 27 +benchmark/threads:1/process_time/real_time_mean -0.1442 -0.1442 90 77 90 77 +benchmark/threads:1/process_time/real_time_median -0.1444 -0.1444 90 77 90 77 +benchmark/threads:1/process_time/real_time_stddev +0.3974 +0.3933 0 0 0 0 +benchmark/threads:1/process_time/real_time_cv +0.6329 +0.6280 0 0 0 0 +OVERALL_GEOMEAN -0.1442 -0.1442 0 0 0 0 +``` +-------------------------------------------- +Here's a breakdown of each row: + +**benchmark/threads:1/process_time/real_time_pvalue**: This shows the _p-value_ for +the statistical test comparing the performance of the process running with one +thread. A value of 0.0000 suggests a statistically significant difference in +performance. The comparison was conducted using the U Test (Mann-Whitney +U Test) with 27 repetitions for each case. + +**benchmark/threads:1/process_time/real_time_mean**: This shows the relative +difference in mean execution time between two different cases. The negative +value (-0.1442) implies that the new process is faster by about 14.42%. The old +time was 90 units, while the new time is 77 units. + +**benchmark/threads:1/process_time/real_time_median**: Similarly, this shows the +relative difference in the median execution time. Again, the new process is +faster by 14.44%. + +**benchmark/threads:1/process_time/real_time_stddev**: This is the relative +difference in the standard deviation of the execution time, which is a measure +of how much variation or dispersion there is from the mean. A positive value +(+0.3974) implies there is more variance in the execution time in the new +process. 
+ +**benchmark/threads:1/process_time/real_time_cv**: CV stands for Coefficient of +Variation. It is the ratio of the standard deviation to the mean. It provides a +standardized measure of dispersion. An increase (+0.6329) indicates more +relative variability in the new process. + +**OVERALL_GEOMEAN**: Geomean stands for geometric mean, a type of average that is +less influenced by outliers. The negative value indicates a general improvement +in the new process. However, given the values are all zero for the old and new +times, this seems to be a mistake or placeholder in the output. + +----------------------------------------- + + + +Let's first try to see what the different columns represent in the above +`compare.py` benchmarking output: + + 1. **Benchmark:** The name of the function being benchmarked, along with the + size of the input (after the slash). + + 2. **Time:** The average time per operation, across all iterations. + + 3. **CPU:** The average CPU time per operation, across all iterations. + + 4. **Iterations:** The number of iterations the benchmark was run to get a + stable estimate. + + 5. **Time Old and Time New:** These represent the average time it takes for a + function to run in two different scenarios or versions. For example, you + might be comparing how fast a function runs before and after you make some + changes to it. + + 6. **CPU Old and CPU New:** These show the average amount of CPU time that the + function uses in two different scenarios or versions. This is similar to + Time Old and Time New, but focuses on CPU usage instead of overall time. + +In the comparison section, the relative differences in both time and CPU time +are displayed for each input size. + + +A statistically-significant difference is determined by a **p-value**, which is +a measure of the probability that the observed difference could have occurred +just by random chance. A smaller p-value indicates stronger evidence against the +null hypothesis. 
+
+**Therefore:**
+ 1. If the p-value is less than the chosen significance level (alpha), we
+    reject the null hypothesis and conclude the benchmarks are significantly
+    different.
+ 2. If the p-value is greater than or equal to alpha, we fail to reject the
+    null hypothesis and treat the two benchmarks as similar.
+
+
+
+The result of said statistical test is additionally communicated through color coding:
+```diff
++ Green:
+```
+  The benchmarks are _**statistically different**_. This could mean the
+  performance has either **significantly improved** or **significantly
+  deteriorated**. You should look at the actual performance numbers to see which
+  is the case.
+```diff
+- Red:
+```
+  The benchmarks are _**statistically similar**_. This means the performance
+  **hasn't significantly changed**.
+
+In statistical terms, **'green'** means we reject the null hypothesis that
+there's no difference in performance, and **'red'** means we fail to reject the
+null hypothesis. This might seem counter-intuitive if you're expecting 'green'
+to mean 'improved performance' and 'red' to mean 'worsened performance'.
+```bash
+  But remember, in this context:
+
+  'Success' means 'successfully finding a difference'.
+  'Failure' means 'failing to find a difference'.
+```
+
+
+Also, please note that **even if** we determine that there **is** a
+statistically-significant difference between the two measurements, it does not
+_necessarily_ mean that the actual benchmarks that were measured **are**
+different, or vice versa, even if we determine that there is **no**
+statistically-significant difference between the two measurements, it does not
+necessarily mean that the actual benchmarks that were measured **are not**
+different.
+
+
+
+### U test
+
+If there is a sufficient repetition count of the benchmarks, the tool can do
+a [U Test](https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test) of the
+null hypothesis that it is equally likely that a randomly selected value from
+one sample will be less than or greater than a randomly selected value from a
+second sample.
+
+If the calculated p-value is lower than the significance
+level alpha, then the result is said to be statistically significant and the
+null hypothesis is rejected. Which in other words means that the two benchmarks
+aren't identical.
+
+**WARNING**: requires **LARGE** (no less than 9) number of repetitions to be
+meaningful!
diff --git a/third_party/benchmark/docs/user_guide.md b/third_party/benchmark/docs/user_guide.md
new file mode 100644
index 0000000..3152762
--- /dev/null
+++ b/third_party/benchmark/docs/user_guide.md
@@ -0,0 +1,1321 @@
+# User Guide
+
+## Command Line
+
+[Output Formats](#output-formats)
+
+[Output Files](#output-files)
+
+[Running Benchmarks](#running-benchmarks)
+
+[Running a Subset of Benchmarks](#running-a-subset-of-benchmarks)
+
+[Result Comparison](#result-comparison)
+
+[Extra Context](#extra-context)
+
+## Library
+
+[Runtime and Reporting Considerations](#runtime-and-reporting-considerations)
+
+[Setup/Teardown](#setupteardown)
+
+[Passing Arguments](#passing-arguments)
+
+[Custom Benchmark Name](#custom-benchmark-name)
+
+[Calculating Asymptotic Complexity](#asymptotic-complexity)
+
+[Templated Benchmarks](#templated-benchmarks)
+
+[Templated Benchmarks that take arguments](#templated-benchmarks-with-arguments)
+
+[Fixtures](#fixtures)
+
+[Custom Counters](#custom-counters)
+
+[Multithreaded Benchmarks](#multithreaded-benchmarks)
+
+[CPU Timers](#cpu-timers)
+
+[Manual Timing](#manual-timing)
+
+[Setting the Time Unit](#setting-the-time-unit)
+
+[Random Interleaving](random_interleaving.md)
+
+[User-Requested Performance Counters](perf_counters.md)
+
+[Preventing 
Optimization](#preventing-optimization) + +[Reporting Statistics](#reporting-statistics) + +[Custom Statistics](#custom-statistics) + +[Memory Usage](#memory-usage) + +[Using RegisterBenchmark](#using-register-benchmark) + +[Exiting with an Error](#exiting-with-an-error) + +[A Faster `KeepRunning` Loop](#a-faster-keep-running-loop) + +## Benchmarking Tips + +[Disabling CPU Frequency Scaling](#disabling-cpu-frequency-scaling) + +[Reducing Variance in Benchmarks](reducing_variance.md) + + + +## Output Formats + +The library supports multiple output formats. Use the +`--benchmark_format=` flag (or set the +`BENCHMARK_FORMAT=` environment variable) to set +the format type. `console` is the default format. + +The Console format is intended to be a human readable format. By default +the format generates color output. Context is output on stderr and the +tabular data on stdout. Example tabular output looks like: + +``` +Benchmark Time(ns) CPU(ns) Iterations +---------------------------------------------------------------------- +BM_SetInsert/1024/1 28928 29349 23853 133.097kiB/s 33.2742k items/s +BM_SetInsert/1024/8 32065 32913 21375 949.487kiB/s 237.372k items/s +BM_SetInsert/1024/10 33157 33648 21431 1.13369MiB/s 290.225k items/s +``` + +The JSON format outputs human readable json split into two top level attributes. +The `context` attribute contains information about the run in general, including +information about the CPU and the date. +The `benchmarks` attribute contains a list of every benchmark run. 
Example json +output looks like: + +```json +{ + "context": { + "date": "2015/03/17-18:40:25", + "num_cpus": 40, + "mhz_per_cpu": 2801, + "cpu_scaling_enabled": false, + "build_type": "debug" + }, + "benchmarks": [ + { + "name": "BM_SetInsert/1024/1", + "iterations": 94877, + "real_time": 29275, + "cpu_time": 29836, + "bytes_per_second": 134066, + "items_per_second": 33516 + }, + { + "name": "BM_SetInsert/1024/8", + "iterations": 21609, + "real_time": 32317, + "cpu_time": 32429, + "bytes_per_second": 986770, + "items_per_second": 246693 + }, + { + "name": "BM_SetInsert/1024/10", + "iterations": 21393, + "real_time": 32724, + "cpu_time": 33355, + "bytes_per_second": 1199226, + "items_per_second": 299807 + } + ] +} +``` + +The CSV format outputs comma-separated values. The `context` is output on stderr +and the CSV itself on stdout. Example CSV output looks like: + +``` +name,iterations,real_time,cpu_time,bytes_per_second,items_per_second,label +"BM_SetInsert/1024/1",65465,17890.7,8407.45,475768,118942, +"BM_SetInsert/1024/8",116606,18810.1,9766.64,3.27646e+06,819115, +"BM_SetInsert/1024/10",106365,17238.4,8421.53,4.74973e+06,1.18743e+06, +``` + + + +## Output Files + +Write benchmark results to a file with the `--benchmark_out=` option +(or set `BENCHMARK_OUT`). Specify the output format with +`--benchmark_out_format={json|console|csv}` (or set +`BENCHMARK_OUT_FORMAT={json|console|csv}`). Note that the 'csv' reporter is +deprecated and the saved `.csv` file +[is not parsable](https://github.com/google/benchmark/issues/794) by csv +parsers. + +Specifying `--benchmark_out` does not suppress the console output. + + + +## Running Benchmarks + +Benchmarks are executed by running the produced binaries. Benchmarks binaries, +by default, accept options that may be specified either through their command +line interface or by setting environment variables before execution. 
For every
+`--option_flag=<value>` CLI switch, a corresponding environment variable
+`OPTION_FLAG=<value>` exists and is used as default if set (CLI switches always
+ prevail). A complete list of CLI options is available by running benchmarks
+ with the `--help` switch.
+
+### Dry runs
+
+To confirm that benchmarks can run successfully without needing to wait for
+multiple repetitions and iterations, the `--benchmark_dry_run` flag can be
+used. This will run the benchmarks as normal, but for 1 iteration and 1
+repetition only.
+
+
+
+## Running a Subset of Benchmarks
+
+The `--benchmark_filter=<regex>` option (or `BENCHMARK_FILTER=<regex>`
+environment variable) can be used to only run the benchmarks that match
+the specified `<regex>`. For example:
+
+```bash
+$ ./run_benchmarks.x --benchmark_filter=BM_memcpy/32
+Run on (1 X 2300 MHz CPU )
+2016-06-25 19:34:24
+Benchmark              Time           CPU Iterations
+----------------------------------------------------
+BM_memcpy/32          11 ns         11 ns   79545455
+BM_memcpy/32k       2181 ns       2185 ns     324074
+BM_memcpy/32          12 ns         12 ns   54687500
+BM_memcpy/32k       1834 ns       1837 ns     357143
+```
+
+## Disabling Benchmarks
+
+It is possible to temporarily disable benchmarks by renaming the benchmark
+function to have the prefix "DISABLED_". This will cause the benchmark to
+be skipped at runtime.
+
+
+
+## Result comparison
+
+It is possible to compare the benchmarking results.
+See [Additional Tooling Documentation](tools.md)
+
+
+
+## Extra Context
+
+Sometimes it's useful to add extra context to the content printed before the
+results. By default this section includes information about the CPU on which
+the benchmarks are running. 
If you do want to add more context, you can use +the `benchmark_context` command line flag: + +```bash +$ ./run_benchmarks --benchmark_context=pwd=`pwd` +Run on (1 x 2300 MHz CPU) +pwd: /home/user/benchmark/ +Benchmark Time CPU Iterations +---------------------------------------------------- +BM_memcpy/32 11 ns 11 ns 79545455 +BM_memcpy/32k 2181 ns 2185 ns 324074 +``` + +You can get the same effect with the API: + +```c++ + benchmark::AddCustomContext("foo", "bar"); +``` + +Note that attempts to add a second value with the same key will fail with an +error message. + + + +## Runtime and Reporting Considerations + +When the benchmark binary is executed, each benchmark function is run serially. +The number of iterations to run is determined dynamically by running the +benchmark a few times and measuring the time taken and ensuring that the +ultimate result will be statistically stable. As such, faster benchmark +functions will be run for more iterations than slower benchmark functions, and +the number of iterations is thus reported. + +In all cases, the number of iterations for which the benchmark is run is +governed by the amount of time the benchmark takes. Concretely, the number of +iterations is at least one, not more than 1e9, until CPU time is greater than +the minimum time, or the wallclock time is 5x minimum time. The minimum time is +set per benchmark by calling `MinTime` on the registered benchmark object. + +Furthermore warming up a benchmark might be necessary in order to get +stable results because of e.g caching effects of the code under benchmark. +Warming up means running the benchmark a given amount of time, before +results are actually taken into account. The amount of time for which +the warmup should be run can be set per benchmark by calling +`MinWarmUpTime` on the registered benchmark object or for all benchmarks +using the `--benchmark_min_warmup_time` command-line option. 
Note that +`MinWarmUpTime` will overwrite the value of `--benchmark_min_warmup_time` +for the single benchmark. How many iterations the warmup run of each +benchmark takes is determined the same way as described in the paragraph +above. Per default the warmup phase is set to 0 seconds and is therefore +disabled. + +Average timings are then reported over the iterations run. If multiple +repetitions are requested using the `--benchmark_repetitions` command-line +option, or at registration time, the benchmark function will be run several +times and statistical results across these repetitions will also be reported. + +As well as the per-benchmark entries, a preamble in the report will include +information about the machine on which the benchmarks are run. + + + +## Setup/Teardown + +Global setup/teardown specific to each benchmark can be done by +passing a callback to Setup/Teardown: + +The setup/teardown callbacks will be invoked once for each benchmark. If the +benchmark is multi-threaded (will run in k threads), they will be invoked +exactly once before each run with k threads. + +If the benchmark uses different size groups of threads, the above will be true +for each size group. + +Eg., + +```c++ +static void DoSetup(const benchmark::State& state) { +} + +static void DoTeardown(const benchmark::State& state) { +} + +static void BM_func(benchmark::State& state) {...} + +BENCHMARK(BM_func)->Arg(1)->Arg(3)->Threads(16)->Threads(32)->Setup(DoSetup)->Teardown(DoTeardown); + +``` + +In this example, `DoSetup` and `DoTearDown` will be invoked 4 times each, +specifically, once for each of this family: + - BM_func_Arg_1_Threads_16, BM_func_Arg_1_Threads_32 + - BM_func_Arg_3_Threads_16, BM_func_Arg_3_Threads_32 + + + +## Passing Arguments + +Sometimes a family of benchmarks can be implemented with just one routine that +takes an extra argument to specify which one of the family of benchmarks to +run. 
For example, the following code defines a family of benchmarks for +measuring the speed of `memcpy()` calls of different lengths: + +```c++ +static void BM_memcpy(benchmark::State& state) { + char* src = new char[state.range(0)]; + char* dst = new char[state.range(0)]; + memset(src, 'x', state.range(0)); + for (auto _ : state) + memcpy(dst, src, state.range(0)); + state.SetBytesProcessed(int64_t(state.iterations()) * + int64_t(state.range(0))); + delete[] src; + delete[] dst; +} +BENCHMARK(BM_memcpy)->Arg(8)->Arg(64)->Arg(512)->Arg(4<<10)->Arg(8<<10); +``` + +The preceding code is quite repetitive, and can be replaced with the following +short-hand. The following invocation will pick a few appropriate arguments in +the specified range and will generate a benchmark for each such argument. + +```c++ +BENCHMARK(BM_memcpy)->Range(8, 8<<10); +``` + +By default the arguments in the range are generated in multiples of eight and +the command above selects [ 8, 64, 512, 4k, 8k ]. In the following code the +range multiplier is changed to multiples of two. + +```c++ +BENCHMARK(BM_memcpy)->RangeMultiplier(2)->Range(8, 8<<10); +``` + +Now arguments generated are [ 8, 16, 32, 64, 128, 256, 512, 1024, 2k, 4k, 8k ]. + +The preceding code shows a method of defining a sparse range. The following +example shows a method of defining a dense range. It is then used to benchmark +the performance of `std::vector` initialization for uniformly increasing sizes. + +```c++ +static void BM_DenseRange(benchmark::State& state) { + for(auto _ : state) { + std::vector v(state.range(0), state.range(0)); + auto data = v.data(); + benchmark::DoNotOptimize(data); + benchmark::ClobberMemory(); + } +} +BENCHMARK(BM_DenseRange)->DenseRange(0, 1024, 128); +``` + +Now arguments generated are [ 0, 128, 256, 384, 512, 640, 768, 896, 1024 ]. + +You might have a benchmark that depends on two or more inputs. 
For example, the +following code defines a family of benchmarks for measuring the speed of set +insertion. + +```c++ +static void BM_SetInsert(benchmark::State& state) { + std::set data; + for (auto _ : state) { + state.PauseTiming(); + data = ConstructRandomSet(state.range(0)); + state.ResumeTiming(); + for (int j = 0; j < state.range(1); ++j) + data.insert(RandomNumber()); + } +} +BENCHMARK(BM_SetInsert) + ->Args({1<<10, 128}) + ->Args({2<<10, 128}) + ->Args({4<<10, 128}) + ->Args({8<<10, 128}) + ->Args({1<<10, 512}) + ->Args({2<<10, 512}) + ->Args({4<<10, 512}) + ->Args({8<<10, 512}); +``` + +The preceding code is quite repetitive, and can be replaced with the following +short-hand. The following macro will pick a few appropriate arguments in the +product of the two specified ranges and will generate a benchmark for each such +pair. + + +```c++ +BENCHMARK(BM_SetInsert)->Ranges({{1<<10, 8<<10}, {128, 512}}); +``` + + +Some benchmarks may require specific argument values that cannot be expressed +with `Ranges`. In this case, `ArgsProduct` offers the ability to generate a +benchmark input for each combination in the product of the supplied vectors. + + +```c++ +BENCHMARK(BM_SetInsert) + ->ArgsProduct({{1<<10, 3<<10, 8<<10}, {20, 40, 60, 80}}) +// would generate the same benchmark arguments as +BENCHMARK(BM_SetInsert) + ->Args({1<<10, 20}) + ->Args({3<<10, 20}) + ->Args({8<<10, 20}) + ->Args({3<<10, 40}) + ->Args({8<<10, 40}) + ->Args({1<<10, 40}) + ->Args({1<<10, 60}) + ->Args({3<<10, 60}) + ->Args({8<<10, 60}) + ->Args({1<<10, 80}) + ->Args({3<<10, 80}) + ->Args({8<<10, 80}); +``` + + +For the most common scenarios, helper methods for creating a list of +integers for a given sparse or dense range are provided. 
+ +```c++ +BENCHMARK(BM_SetInsert) + ->ArgsProduct({ + benchmark::CreateRange(8, 128, /*multi=*/2), + benchmark::CreateDenseRange(1, 4, /*step=*/1) + }) +// would generate the same benchmark arguments as +BENCHMARK(BM_SetInsert) + ->ArgsProduct({ + {8, 16, 32, 64, 128}, + {1, 2, 3, 4} + }); +``` + +For more complex patterns of inputs, passing a custom function to `Apply` allows +programmatic specification of an arbitrary set of arguments on which to run the +benchmark. The following example enumerates a dense range on one parameter, +and a sparse range on the second. + +```c++ +static void CustomArguments(benchmark::internal::Benchmark* b) { + for (int i = 0; i <= 10; ++i) + for (int j = 32; j <= 1024*1024; j *= 8) + b->Args({i, j}); +} +BENCHMARK(BM_SetInsert)->Apply(CustomArguments); +``` + +### Passing Arbitrary Arguments to a Benchmark + +In C++11 it is possible to define a benchmark that takes an arbitrary number +of extra arguments. The `BENCHMARK_CAPTURE(func, test_case_name, ...args)` +macro creates a benchmark that invokes `func` with the `benchmark::State` as +the first argument followed by the specified `args...`. +The `test_case_name` is appended to the name of the benchmark and +should describe the values passed. + +```c++ +template +void BM_takes_args(benchmark::State& state, Args&&... args) { + auto args_tuple = std::make_tuple(std::move(args)...); + for (auto _ : state) { + std::cout << std::get<0>(args_tuple) << ": " << std::get<1>(args_tuple) + << '\n'; + [...] + } +} +// Registers a benchmark named "BM_takes_args/int_string_test" that passes +// the specified values to `args`. +BENCHMARK_CAPTURE(BM_takes_args, int_string_test, 42, std::string("abc")); + +// Registers the same benchmark "BM_takes_args/int_test" that passes +// the specified values to `args`. +BENCHMARK_CAPTURE(BM_takes_args, int_test, 42, 43); +``` + +Note that elements of `...args` may refer to global variables. Users should +avoid modifying global state inside of a benchmark. 
+ + + +## Calculating Asymptotic Complexity (Big O) + +Asymptotic complexity might be calculated for a family of benchmarks. The +following code will calculate the coefficient for the high-order term in the +running time and the normalized root-mean square error of string comparison. + +```c++ +static void BM_StringCompare(benchmark::State& state) { + std::string s1(state.range(0), '-'); + std::string s2(state.range(0), '-'); + for (auto _ : state) { + auto comparison_result = s1.compare(s2); + benchmark::DoNotOptimize(comparison_result); + } + state.SetComplexityN(state.range(0)); +} +BENCHMARK(BM_StringCompare) + ->RangeMultiplier(2)->Range(1<<10, 1<<18)->Complexity(benchmark::oN); +``` + +As shown in the following invocation, asymptotic complexity might also be +calculated automatically. + +```c++ +BENCHMARK(BM_StringCompare) + ->RangeMultiplier(2)->Range(1<<10, 1<<18)->Complexity(); +``` + +The following code will specify asymptotic complexity with a lambda function, +that might be used to customize high-order term calculation. + +```c++ +BENCHMARK(BM_StringCompare)->RangeMultiplier(2) + ->Range(1<<10, 1<<18)->Complexity([](benchmark::IterationCount n)->double{return n; }); +``` + + + +## Custom Benchmark Name + +You can change the benchmark's name as follows: + +```c++ +BENCHMARK(BM_memcpy)->Name("memcpy")->RangeMultiplier(2)->Range(8, 8<<10); +``` + +The invocation will execute the benchmark as before using `BM_memcpy` but changes +the prefix in the report to `memcpy`. + + + +## Templated Benchmarks + +This example produces and consumes messages of size `sizeof(v)` `range_x` +times. It also outputs throughput in the absence of multiprogramming. 
+ +```c++ +template void BM_Sequential(benchmark::State& state) { + Q q; + typename Q::value_type v; + for (auto _ : state) { + for (int i = state.range(0); i--; ) + q.push(v); + for (int e = state.range(0); e--; ) + q.Wait(&v); + } + // actually messages, not bytes: + state.SetBytesProcessed( + static_cast(state.iterations())*state.range(0)); +} +// C++03 +BENCHMARK_TEMPLATE(BM_Sequential, WaitQueue)->Range(1<<0, 1<<10); + +// C++11 or newer, you can use the BENCHMARK macro with template parameters: +BENCHMARK(BM_Sequential>)->Range(1<<0, 1<<10); + +``` + +Three macros are provided for adding benchmark templates. + +```c++ +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK(func<...>) // Takes any number of parameters. +#else // C++ < C++11 +#define BENCHMARK_TEMPLATE(func, arg1) +#endif +#define BENCHMARK_TEMPLATE1(func, arg1) +#define BENCHMARK_TEMPLATE2(func, arg1, arg2) +``` + + + +## Templated Benchmarks that take arguments + +Sometimes there is a need to template benchmarks, and provide arguments to them. 
+ +```c++ +template void BM_Sequential_With_Step(benchmark::State& state, int step) { + Q q; + typename Q::value_type v; + for (auto _ : state) { + for (int i = state.range(0); i-=step; ) + q.push(v); + for (int e = state.range(0); e-=step; ) + q.Wait(&v); + } + // actually messages, not bytes: + state.SetBytesProcessed( + static_cast(state.iterations())*state.range(0)); +} + +BENCHMARK_TEMPLATE1_CAPTURE(BM_Sequential, WaitQueue, Step1, 1)->Range(1<<0, 1<<10); +``` + + + +## Fixtures + +Fixture tests are created by first defining a type that derives from +`::benchmark::Fixture` and then creating/registering the tests using the +following macros: + +* `BENCHMARK_F(ClassName, Method)` +* `BENCHMARK_DEFINE_F(ClassName, Method)` +* `BENCHMARK_REGISTER_F(ClassName, Method)` + +For Example: + +```c++ +class MyFixture : public benchmark::Fixture { +public: + void SetUp(::benchmark::State& state) { + } + + void TearDown(::benchmark::State& state) { + } +}; + +// Defines and registers `FooTest` using the class `MyFixture`. +BENCHMARK_F(MyFixture, FooTest)(benchmark::State& st) { + for (auto _ : st) { + ... + } +} + +// Only defines `BarTest` using the class `MyFixture`. +BENCHMARK_DEFINE_F(MyFixture, BarTest)(benchmark::State& st) { + for (auto _ : st) { + ... + } +} +// `BarTest` is NOT registered. +BENCHMARK_REGISTER_F(MyFixture, BarTest)->Threads(2); +// `BarTest` is now registered. +``` + +### Templated Fixtures + +Also you can create templated fixture by using the following macros: + +* `BENCHMARK_TEMPLATE_F(ClassName, Method, ...)` +* `BENCHMARK_TEMPLATE_DEFINE_F(ClassName, Method, ...)` + +For example: + +```c++ +template +class MyFixture : public benchmark::Fixture {}; + +// Defines and registers `IntTest` using the class template `MyFixture`. +BENCHMARK_TEMPLATE_F(MyFixture, IntTest, int)(benchmark::State& st) { + for (auto _ : st) { + ... + } +} + +// Only defines `DoubleTest` using the class template `MyFixture`. 
+BENCHMARK_TEMPLATE_DEFINE_F(MyFixture, DoubleTest, double)(benchmark::State& st) { + for (auto _ : st) { + ... + } +} +// `DoubleTest` is NOT registered. +BENCHMARK_REGISTER_F(MyFixture, DoubleTest)->Threads(2); +// `DoubleTest` is now registered. +``` + + + +## Custom Counters + +You can add your own counters with user-defined names. The example below +will add columns "Foo", "Bar" and "Baz" in its output: + +```c++ +static void UserCountersExample1(benchmark::State& state) { + double numFoos = 0, numBars = 0, numBazs = 0; + for (auto _ : state) { + // ... count Foo,Bar,Baz events + } + state.counters["Foo"] = numFoos; + state.counters["Bar"] = numBars; + state.counters["Baz"] = numBazs; +} +``` + +The `state.counters` object is a `std::map` with `std::string` keys +and `Counter` values. The latter is a `double`-like class, via an implicit +conversion to `double&`. Thus you can use all of the standard arithmetic +assignment operators (`=,+=,-=,*=,/=`) to change the value of each counter. + +In multithreaded benchmarks, each counter is set on the calling thread only. +When the benchmark finishes, the counters from each thread will be summed; +the resulting sum is the value which will be shown for the benchmark. + +The `Counter` constructor accepts three parameters: the value as a `double` +; a bit flag which allows you to show counters as rates, and/or as per-thread +iteration, and/or as per-thread averages, and/or iteration invariants, +and/or finally inverting the result; and a flag specifying the 'unit' - i.e. +is 1k a 1000 (default, `benchmark::Counter::OneK::kIs1000`), or 1024 +(`benchmark::Counter::OneK::kIs1024`)? + +```c++ + // sets a simple counter + state.counters["Foo"] = numFoos; + + // Set the counter as a rate. It will be presented divided + // by the duration of the benchmark. + // Meaning: per one second, how many 'foo's are processed? + state.counters["FooRate"] = Counter(numFoos, benchmark::Counter::kIsRate); + + // Set the counter as a rate. 
It will be presented divided + // by the duration of the benchmark, and the result inverted. + // Meaning: how many seconds it takes to process one 'foo'? + state.counters["FooInvRate"] = Counter(numFoos, benchmark::Counter::kIsRate | benchmark::Counter::kInvert); + + // Set the counter as a thread-average quantity. It will + // be presented divided by the number of threads. + state.counters["FooAvg"] = Counter(numFoos, benchmark::Counter::kAvgThreads); + + // There's also a combined flag: + state.counters["FooAvgRate"] = Counter(numFoos,benchmark::Counter::kAvgThreadsRate); + + // This says that we process with the rate of state.range(0) bytes every iteration: + state.counters["BytesProcessed"] = Counter(state.range(0), benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1024); +``` + +When you're compiling in C++11 mode or later you can use `insert()` with +`std::initializer_list`: + + +```c++ + // With C++11, this can be done: + state.counters.insert({{"Foo", numFoos}, {"Bar", numBars}, {"Baz", numBazs}}); + // ... instead of: + state.counters["Foo"] = numFoos; + state.counters["Bar"] = numBars; + state.counters["Baz"] = numBazs; +``` + + +### Counter Reporting + +When using the console reporter, by default, user counters are printed at +the end after the table, the same way as ``bytes_processed`` and +``items_processed``. This is best for cases in which there are few counters, +or where there are only a couple of lines per benchmark. Here's an example of +the default output: + +``` +------------------------------------------------------------------------------ +Benchmark Time CPU Iterations UserCounters... 
+------------------------------------------------------------------------------ +BM_UserCounter/threads:8 2248 ns 10277 ns 68808 Bar=16 Bat=40 Baz=24 Foo=8 +BM_UserCounter/threads:1 9797 ns 9788 ns 71523 Bar=2 Bat=5 Baz=3 Foo=1024m +BM_UserCounter/threads:2 4924 ns 9842 ns 71036 Bar=4 Bat=10 Baz=6 Foo=2 +BM_UserCounter/threads:4 2589 ns 10284 ns 68012 Bar=8 Bat=20 Baz=12 Foo=4 +BM_UserCounter/threads:8 2212 ns 10287 ns 68040 Bar=16 Bat=40 Baz=24 Foo=8 +BM_UserCounter/threads:16 1782 ns 10278 ns 68144 Bar=32 Bat=80 Baz=48 Foo=16 +BM_UserCounter/threads:32 1291 ns 10296 ns 68256 Bar=64 Bat=160 Baz=96 Foo=32 +BM_UserCounter/threads:4 2615 ns 10307 ns 68040 Bar=8 Bat=20 Baz=12 Foo=4 +BM_Factorial 26 ns 26 ns 26608979 40320 +BM_Factorial/real_time 26 ns 26 ns 26587936 40320 +BM_CalculatePiRange/1 16 ns 16 ns 45704255 0 +BM_CalculatePiRange/8 73 ns 73 ns 9520927 3.28374 +BM_CalculatePiRange/64 609 ns 609 ns 1140647 3.15746 +BM_CalculatePiRange/512 4900 ns 4901 ns 142696 3.14355 +``` + +If this doesn't suit you, you can print each counter as a table column by +passing the flag `--benchmark_counters_tabular=true` to the benchmark +application. This is best for cases in which there are a lot of counters, or +a lot of lines per individual benchmark. Note that this will trigger a +reprinting of the table header any time the counter set changes between +individual benchmarks. 
Here's an example of corresponding output when +`--benchmark_counters_tabular=true` is passed: + +``` +--------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations Bar Bat Baz Foo +--------------------------------------------------------------------------------------- +BM_UserCounter/threads:8 2198 ns 9953 ns 70688 16 40 24 8 +BM_UserCounter/threads:1 9504 ns 9504 ns 73787 2 5 3 1 +BM_UserCounter/threads:2 4775 ns 9550 ns 72606 4 10 6 2 +BM_UserCounter/threads:4 2508 ns 9951 ns 70332 8 20 12 4 +BM_UserCounter/threads:8 2055 ns 9933 ns 70344 16 40 24 8 +BM_UserCounter/threads:16 1610 ns 9946 ns 70720 32 80 48 16 +BM_UserCounter/threads:32 1192 ns 9948 ns 70496 64 160 96 32 +BM_UserCounter/threads:4 2506 ns 9949 ns 70332 8 20 12 4 +-------------------------------------------------------------- +Benchmark Time CPU Iterations +-------------------------------------------------------------- +BM_Factorial 26 ns 26 ns 26392245 40320 +BM_Factorial/real_time 26 ns 26 ns 26494107 40320 +BM_CalculatePiRange/1 15 ns 15 ns 45571597 0 +BM_CalculatePiRange/8 74 ns 74 ns 9450212 3.28374 +BM_CalculatePiRange/64 595 ns 595 ns 1173901 3.15746 +BM_CalculatePiRange/512 4752 ns 4752 ns 147380 3.14355 +BM_CalculatePiRange/4k 37970 ns 37972 ns 18453 3.14184 +BM_CalculatePiRange/32k 303733 ns 303744 ns 2305 3.14162 +BM_CalculatePiRange/256k 2434095 ns 2434186 ns 288 3.1416 +BM_CalculatePiRange/1024k 9721140 ns 9721413 ns 71 3.14159 +BM_CalculatePi/threads:8 2255 ns 9943 ns 70936 +``` + +Note above the additional header printed when the benchmark changes from +``BM_UserCounter`` to ``BM_Factorial``. This is because ``BM_Factorial`` does +not have the same counter set as ``BM_UserCounter``. 
+ + + +## Multithreaded Benchmarks + +In a multithreaded test (benchmark invoked by multiple threads simultaneously), +it is guaranteed that none of the threads will start until all have reached +the start of the benchmark loop, and all will have finished before any thread +exits the benchmark loop. (This behavior is also provided by the `KeepRunning()` +API) As such, any global setup or teardown can be wrapped in a check against the thread +index: + +```c++ +static void BM_MultiThreaded(benchmark::State& state) { + if (state.thread_index() == 0) { + // Setup code here. + } + for (auto _ : state) { + // Run the test as normal. + } + if (state.thread_index() == 0) { + // Teardown code here. + } +} +BENCHMARK(BM_MultiThreaded)->Threads(2); +``` + +To run the benchmark across a range of thread counts, instead of `Threads`, use +`ThreadRange`. This takes two parameters (`min_threads` and `max_threads`) and +runs the benchmark once for values in the inclusive range. For example: + +```c++ +BENCHMARK(BM_MultiThreaded)->ThreadRange(1, 8); +``` + +will run `BM_MultiThreaded` with thread counts 1, 2, 4, and 8. + +If the benchmarked code itself uses threads and you want to compare it to +single-threaded code, you may want to use real-time ("wallclock") measurements +for latency comparisons: + +```c++ +BENCHMARK(BM_test)->Range(8, 8<<10)->UseRealTime(); +``` + +Without `UseRealTime`, CPU time is used by default. + + + +## CPU Timers + +By default, the CPU timer only measures the time spent by the main thread. +If the benchmark itself uses threads internally, this measurement may not +be what you are looking for. Instead, there is a way to measure the total +CPU usage of the process, by all the threads. 
+ +```c++ +void callee(int i); + +static void MyMain(int size) { +#pragma omp parallel for + for(int i = 0; i < size; i++) + callee(i); +} + +static void BM_OpenMP(benchmark::State& state) { + for (auto _ : state) + MyMain(state.range(0)); +} + +// Measure the time spent by the main thread, use it to decide for how long to +// run the benchmark loop. Depending on the internal implementation detail may +// measure to anywhere from near-zero (the overhead spent before/after work +// handoff to worker thread[s]) to the whole single-thread time. +BENCHMARK(BM_OpenMP)->Range(8, 8<<10); + +// Measure the user-visible time, the wall clock (literally, the time that +// has passed on the clock on the wall), use it to decide for how long to +// run the benchmark loop. This will always be meaningful, and will match the +// time spent by the main thread in single-threaded case, in general decreasing +// with the number of internal threads doing the work. +BENCHMARK(BM_OpenMP)->Range(8, 8<<10)->UseRealTime(); + +// Measure the total CPU consumption, use it to decide for how long to +// run the benchmark loop. This will always measure to no less than the +// time spent by the main thread in single-threaded case. +BENCHMARK(BM_OpenMP)->Range(8, 8<<10)->MeasureProcessCPUTime(); + +// A mixture of the last two. Measure the total CPU consumption, but use the +// wall clock to decide for how long to run the benchmark loop. +BENCHMARK(BM_OpenMP)->Range(8, 8<<10)->MeasureProcessCPUTime()->UseRealTime(); +``` + +### Controlling Timers + +Normally, the entire duration of the work loop (`for (auto _ : state) {}`) +is measured. But sometimes, it is necessary to do some work inside of +that loop, every iteration, but without counting that time to the benchmark time. +That is possible, although it is not recommended, since it has high overhead. 
+
+
+
+```c++
+static void BM_SetInsert_With_Timer_Control(benchmark::State& state) {
+  std::set<int> data;
+  for (auto _ : state) {
+    state.PauseTiming(); // Stop timers. They will not count until they are resumed.
+    data = ConstructRandomSet(state.range(0)); // Do something that should not be measured
+    state.ResumeTiming(); // And resume timers. They are now counting again.
+    // The rest will be measured.
+    for (int j = 0; j < state.range(1); ++j)
+      data.insert(RandomNumber());
+  }
+}
+BENCHMARK(BM_SetInsert_With_Timer_Control)->Ranges({{1<<10, 8<<10}, {128, 512}});
+```
+
+
+
+
+## Manual Timing
+
+For benchmarking something for which neither CPU time nor real-time are
+correct or accurate enough, completely manual timing is supported using
+the `UseManualTime` function.
+
+When `UseManualTime` is used, the benchmarked code must call
+`SetIterationTime` once per iteration of the benchmark loop to
+report the manually measured time.
+
+An example use case for this is benchmarking GPU execution (e.g. OpenCL
+or CUDA kernels, OpenGL or Vulkan or Direct3D draw calls), which cannot
+be accurately measured using CPU time or real-time. Instead, they can be
+measured accurately using a dedicated API, and these measurement results
+can be reported back with `SetIterationTime`.
+
+```c++
+static void BM_ManualTiming(benchmark::State& state) {
+  int microseconds = state.range(0);
+  std::chrono::duration<double, std::micro> sleep_duration {
+    static_cast<double>(microseconds)
+  };
+
+  for (auto _ : state) {
+    auto start = std::chrono::high_resolution_clock::now();
+    // Simulate some useful workload with a sleep
+    std::this_thread::sleep_for(sleep_duration);
+    auto end = std::chrono::high_resolution_clock::now();
+
+    auto elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::duration<double>>(
+        end - start);
+
+    state.SetIterationTime(elapsed_seconds.count());
+  }
+}
+BENCHMARK(BM_ManualTiming)->Range(1, 1<<17)->UseManualTime();
+```
+
+
+
+## Setting the Time Unit
+
+If a benchmark runs a few milliseconds it may be hard to visually compare the
+measured times, since the output data is given in nanoseconds per default. In
+order to manually set the time unit, you can specify it manually:
+
+```c++
+BENCHMARK(BM_test)->Unit(benchmark::kMillisecond);
+```
+
+Additionally the default time unit can be set globally with the
+`--benchmark_time_unit={ns|us|ms|s}` command line argument. The argument only
+affects benchmarks where the time unit is not set explicitly.
+
+
+
+## Preventing Optimization
+
+To prevent a value or expression from being optimized away by the compiler
+the `benchmark::DoNotOptimize(...)` and `benchmark::ClobberMemory()`
+functions can be used.
+
+```c++
+static void BM_test(benchmark::State& state) {
+  for (auto _ : state) {
+    int x = 0;
+    for (int i=0; i < 64; ++i) {
+      benchmark::DoNotOptimize(x += i);
+    }
+  }
+}
+```
+
+`DoNotOptimize(<expr>)` forces the *result* of `<expr>` to be stored in either
+memory or a register. For GNU based compilers it acts as read/write barrier
+for global memory. More specifically it forces the compiler to flush pending
+writes to memory and reload any other values as necessary.
+
+Note that `DoNotOptimize(<expr>)` does not prevent optimizations on `<expr>`
+in any way. `<expr>` may even be removed entirely when the result is already
+known. For example:
+
+```c++
+  // Example 1: `<expr>` is removed entirely.
+  int foo(int x) { return x + 42; }
+  while (...) DoNotOptimize(foo(0)); // Optimized to DoNotOptimize(42);
+
+  // Example 2: Result of '<expr>' is only reused.
+  int bar(int) __attribute__((const));
+  while (...) DoNotOptimize(bar(0)); // Optimized to:
+  // int __result__ = bar(0);
+  // while (...) DoNotOptimize(__result__);
+```
+
+The second tool for preventing optimizations is `ClobberMemory()`. In essence
+`ClobberMemory()` forces the compiler to perform all pending writes to global
+memory. Memory managed by block scope objects must be "escaped" using
+`DoNotOptimize(...)` before it can be clobbered. In the below example
+`ClobberMemory()` prevents the call to `v.push_back(42)` from being optimized
+away.
+
+```c++
+static void BM_vector_push_back(benchmark::State& state) {
+  for (auto _ : state) {
+    std::vector<int> v;
+    v.reserve(1);
+    auto data = v.data();           // Allow v.data() to be clobbered. Pass as non-const
+    benchmark::DoNotOptimize(data); // lvalue to avoid undesired compiler optimizations
+    v.push_back(42);
+    benchmark::ClobberMemory(); // Force 42 to be written to memory.
+  }
+}
+```
+
+Note that `ClobberMemory()` is only available for GNU or MSVC based compilers.
+
+
+
+## Statistics: Reporting the Mean, Median and Standard Deviation / Coefficient of variation of Repeated Benchmarks
+
+By default each benchmark is run once and that single result is reported.
+However benchmarks are often noisy and a single result may not be representative
+of the overall behavior. For this reason it's possible to repeatedly rerun the
+benchmark.
+
+The number of runs of each benchmark is specified globally by the
+`--benchmark_repetitions` flag or on a per benchmark basis by calling
+`Repetitions` on the registered benchmark object. When a benchmark is run more
+than once the mean, median, standard deviation and coefficient of variation
+of the runs will be reported.
+ +Additionally the `--benchmark_report_aggregates_only={true|false}`, +`--benchmark_display_aggregates_only={true|false}` flags or +`ReportAggregatesOnly(bool)`, `DisplayAggregatesOnly(bool)` functions can be +used to change how repeated tests are reported. By default the result of each +repeated run is reported. When `report aggregates only` option is `true`, +only the aggregates (i.e. mean, median, standard deviation and coefficient +of variation, maybe complexity measurements if they were requested) of the runs +is reported, to both the reporters - standard output (console), and the file. +However when only the `display aggregates only` option is `true`, +only the aggregates are displayed in the standard output, while the file +output still contains everything. +Calling `ReportAggregatesOnly(bool)` / `DisplayAggregatesOnly(bool)` on a +registered benchmark object overrides the value of the appropriate flag for that +benchmark. + + + +## Custom Statistics + +While having these aggregates is nice, this may not be enough for everyone. +For example you may want to know what the largest observation is, e.g. because +you have some real-time constraints. This is easy. The following code will +specify a custom statistic to be calculated, defined by a lambda function. 
+
+```c++
+void BM_spin_empty(benchmark::State& state) {
+  for (auto _ : state) {
+    for (int x = 0; x < state.range(0); ++x) {
+      benchmark::DoNotOptimize(x);
+    }
+  }
+}
+
+BENCHMARK(BM_spin_empty)
+  ->Repetitions(3)  // or add option --benchmark_repetitions=3
+  ->ComputeStatistics("max", [](const std::vector<double>& v) -> double {
+    return *(std::max_element(std::begin(v), std::end(v)));
+  })
+  ->Arg(512);
+```
+
+While usually the statistics produce values in time units,
+you can also produce percentages:
+
+```c++
+void BM_spin_empty(benchmark::State& state) {
+  for (auto _ : state) {
+    for (int x = 0; x < state.range(0); ++x) {
+      benchmark::DoNotOptimize(x);
+    }
+  }
+}
+
+BENCHMARK(BM_spin_empty)
+  ->Repetitions(3)  // or add option --benchmark_repetitions=3
+  ->ComputeStatistics("ratio", [](const std::vector<double>& v) -> double {
+    return v.front() / v.back();
+  }, benchmark::StatisticUnit::kPercentage)
+  ->Arg(512);
+```
+
+
+
+## Memory Usage
+
+It's often useful to also track memory usage for benchmarks, alongside CPU
+performance. For this reason, benchmark offers the `RegisterMemoryManager`
+method that allows a custom `MemoryManager` to be injected.
+
+If set, the `MemoryManager::Start` and `MemoryManager::Stop` methods will be
+called at the start and end of benchmark runs to allow user code to fill out
+a report on the number of allocations, bytes used, etc.
+
+This data will then be reported alongside other performance data, currently
+only when using JSON output.
+
+
+
+## Profiling
+
+It's often useful to also profile benchmarks in particular ways, in addition to
+CPU performance. For this reason, benchmark offers the `RegisterProfilerManager`
+method that allows a custom `ProfilerManager` to be injected.
+
+If set, the `ProfilerManager::AfterSetupStart` and
+`ProfilerManager::BeforeTeardownStop` methods will be called at the start and
+end of a separate benchmark run to allow user code to collect and report
+user-provided profile metrics.
+ +Output collected from this profiling run must be reported separately. + + + +## Using RegisterBenchmark(name, fn, args...) + +The `RegisterBenchmark(name, func, args...)` function provides an alternative +way to create and register benchmarks. +`RegisterBenchmark(name, func, args...)` creates, registers, and returns a +pointer to a new benchmark with the specified `name` that invokes +`func(st, args...)` where `st` is a `benchmark::State` object. + +Unlike the `BENCHMARK` registration macros, which can only be used at the global +scope, the `RegisterBenchmark` can be called anywhere. This allows for +benchmark tests to be registered programmatically. + +Additionally `RegisterBenchmark` allows any callable object to be registered +as a benchmark. Including capturing lambdas and function objects. + +For Example: +```c++ +auto BM_test = [](benchmark::State& st, auto Inputs) { /* ... */ }; + +int main(int argc, char** argv) { + for (auto& test_input : { /* ... */ }) + benchmark::RegisterBenchmark(test_input.name(), BM_test, test_input); + benchmark::Initialize(&argc, argv); + benchmark::RunSpecifiedBenchmarks(); + benchmark::Shutdown(); +} +``` + + + +## Exiting with an Error + +When errors caused by external influences, such as file I/O and network +communication, occur within a benchmark the +`State::SkipWithError(const std::string& msg)` function can be used to skip that run +of benchmark and report the error. Note that only future iterations of the +`KeepRunning()` are skipped. For the ranged-for version of the benchmark loop +Users must explicitly exit the loop, otherwise all iterations will be performed. +Users may explicitly return to exit the benchmark immediately. + +The `SkipWithError(...)` function may be used at any point within the benchmark, +including before and after the benchmark loop. Moreover, if `SkipWithError(...)` +has been used, it is not required to reach the benchmark loop and one may return +from the benchmark function early. 
+ +For example: + +```c++ +static void BM_test(benchmark::State& state) { + auto resource = GetResource(); + if (!resource.good()) { + state.SkipWithError("Resource is not good!"); + // KeepRunning() loop will not be entered. + } + while (state.KeepRunning()) { + auto data = resource.read_data(); + if (!resource.good()) { + state.SkipWithError("Failed to read data!"); + break; // Needed to skip the rest of the iteration. + } + do_stuff(data); + } +} + +static void BM_test_ranged_fo(benchmark::State & state) { + auto resource = GetResource(); + if (!resource.good()) { + state.SkipWithError("Resource is not good!"); + return; // Early return is allowed when SkipWithError() has been used. + } + for (auto _ : state) { + auto data = resource.read_data(); + if (!resource.good()) { + state.SkipWithError("Failed to read data!"); + break; // REQUIRED to prevent all further iterations. + } + do_stuff(data); + } +} +``` + + +## A Faster KeepRunning Loop + +In C++11 mode, a ranged-based for loop should be used in preference to +the `KeepRunning` loop for running the benchmarks. For example: + +```c++ +static void BM_Fast(benchmark::State &state) { + for (auto _ : state) { + FastOperation(); + } +} +BENCHMARK(BM_Fast); +``` + +The reason the ranged-for loop is faster than using `KeepRunning`, is +because `KeepRunning` requires a memory load and store of the iteration count +ever iteration, whereas the ranged-for variant is able to keep the iteration count +in a register. 
+ +For example, an empty inner loop of using the ranged-based for method looks like: + +```asm +# Loop Init + mov rbx, qword ptr [r14 + 104] + call benchmark::State::StartKeepRunning() + test rbx, rbx + je .LoopEnd +.LoopHeader: # =>This Inner Loop Header: Depth=1 + add rbx, -1 + jne .LoopHeader +.LoopEnd: +``` + +Compared to an empty `KeepRunning` loop, which looks like: + +```asm +.LoopHeader: # in Loop: Header=BB0_3 Depth=1 + cmp byte ptr [rbx], 1 + jne .LoopInit +.LoopBody: # =>This Inner Loop Header: Depth=1 + mov rax, qword ptr [rbx + 8] + lea rcx, [rax + 1] + mov qword ptr [rbx + 8], rcx + cmp rax, qword ptr [rbx + 104] + jb .LoopHeader + jmp .LoopEnd +.LoopInit: + mov rdi, rbx + call benchmark::State::StartKeepRunning() + jmp .LoopBody +.LoopEnd: +``` + +Unless C++03 compatibility is required, the ranged-for variant of writing +the benchmark loop should be preferred. + + + +## Disabling CPU Frequency Scaling + +If you see this error: + +``` +***WARNING*** CPU scaling is enabled, the benchmark real time measurements may +be noisy and will incur extra overhead. +``` + +you might want to disable the CPU frequency scaling while running the +benchmark, as well as consider other ways to stabilize the performance of +your system while benchmarking. + +See [Reducing Variance](reducing_variance.md) for more information. diff --git a/third_party/benchmark/include/benchmark/benchmark.h b/third_party/benchmark/include/benchmark/benchmark.h new file mode 100644 index 0000000..86f9dbb --- /dev/null +++ b/third_party/benchmark/include/benchmark/benchmark.h @@ -0,0 +1,2110 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Support for registering benchmarks for functions. + +/* Example usage: +// Define a function that executes the code to be measured a +// specified number of times: +static void BM_StringCreation(benchmark::State& state) { + for (auto _ : state) + std::string empty_string; +} + +// Register the function as a benchmark +BENCHMARK(BM_StringCreation); + +// Define another benchmark +static void BM_StringCopy(benchmark::State& state) { + std::string x = "hello"; + for (auto _ : state) + std::string copy(x); +} +BENCHMARK(BM_StringCopy); + +// Augment the main() program to invoke benchmarks if specified +// via the --benchmark_filter command line flag. E.g., +// my_unittest --benchmark_filter=all +// my_unittest --benchmark_filter=BM_StringCreation +// my_unittest --benchmark_filter=String +// my_unittest --benchmark_filter='Copy|Creation' +int main(int argc, char** argv) { + benchmark::Initialize(&argc, argv); + benchmark::RunSpecifiedBenchmarks(); + benchmark::Shutdown(); + return 0; +} + +// Sometimes a family of microbenchmarks can be implemented with +// just one routine that takes an extra argument to specify which +// one of the family of benchmarks to run. 
For example, the following +// code defines a family of microbenchmarks for measuring the speed +// of memcpy() calls of different lengths: + +static void BM_memcpy(benchmark::State& state) { + char* src = new char[state.range(0)]; char* dst = new char[state.range(0)]; + memset(src, 'x', state.range(0)); + for (auto _ : state) + memcpy(dst, src, state.range(0)); + state.SetBytesProcessed(state.iterations() * state.range(0)); + delete[] src; delete[] dst; +} +BENCHMARK(BM_memcpy)->Arg(8)->Arg(64)->Arg(512)->Arg(1<<10)->Arg(8<<10); + +// The preceding code is quite repetitive, and can be replaced with the +// following short-hand. The following invocation will pick a few +// appropriate arguments in the specified range and will generate a +// microbenchmark for each such argument. +BENCHMARK(BM_memcpy)->Range(8, 8<<10); + +// You might have a microbenchmark that depends on two inputs. For +// example, the following code defines a family of microbenchmarks for +// measuring the speed of set insertion. +static void BM_SetInsert(benchmark::State& state) { + set data; + for (auto _ : state) { + state.PauseTiming(); + data = ConstructRandomSet(state.range(0)); + state.ResumeTiming(); + for (int j = 0; j < state.range(1); ++j) + data.insert(RandomNumber()); + } +} +BENCHMARK(BM_SetInsert) + ->Args({1<<10, 128}) + ->Args({2<<10, 128}) + ->Args({4<<10, 128}) + ->Args({8<<10, 128}) + ->Args({1<<10, 512}) + ->Args({2<<10, 512}) + ->Args({4<<10, 512}) + ->Args({8<<10, 512}); + +// The preceding code is quite repetitive, and can be replaced with +// the following short-hand. The following macro will pick a few +// appropriate arguments in the product of the two specified ranges +// and will generate a microbenchmark for each such pair. +BENCHMARK(BM_SetInsert)->Ranges({{1<<10, 8<<10}, {128, 512}}); + +// For more complex patterns of inputs, passing a custom function +// to Apply allows programmatic specification of an +// arbitrary set of arguments to run the microbenchmark on. 
+
+// The following example enumerates a dense range on
+// one parameter, and a sparse range on the second.
+static void CustomArguments(benchmark::internal::Benchmark* b) {
+  for (int i = 0; i <= 10; ++i)
+    for (int j = 32; j <= 1024*1024; j *= 8)
+      b->Args({i, j});
+}
+BENCHMARK(BM_SetInsert)->Apply(CustomArguments);
+
+// Templated microbenchmarks work the same way:
+// Produce then consume 'size' messages 'iters' times
+// Measures throughput in the absence of multiprogramming.
+template <class Q> int BM_Sequential(benchmark::State& state) {
+  Q q;
+  typename Q::value_type v;
+  for (auto _ : state) {
+    for (int i = state.range(0); i--; )
+      q.push(v);
+    for (int e = state.range(0); e--; )
+      q.Wait(&v);
+  }
+  // actually messages, not bytes:
+  state.SetBytesProcessed(state.iterations() * state.range(0));
+}
+BENCHMARK_TEMPLATE(BM_Sequential, WaitQueue<int>)->Range(1<<0, 1<<10);
+
+Use `Benchmark::MinTime(double t)` to set the minimum time used to run the
+benchmark. This option overrides the `benchmark_min_time` flag.
+
+void BM_test(benchmark::State& state) {
+ ... body ...
+}
+BENCHMARK(BM_test)->MinTime(2.0); // Run for at least 2 seconds.
+
+In a multithreaded test, it is guaranteed that none of the threads will start
+until all have reached the loop start, and all will have finished before any
+thread exits the loop body. As such, any global setup or teardown you want to
+do can be wrapped in a check against the thread index:
+
+static void BM_MultiThreaded(benchmark::State& state) {
+  if (state.thread_index() == 0) {
+    // Setup code here.
+  }
+  for (auto _ : state) {
+    // Run the test as normal.
+  }
+  if (state.thread_index() == 0) {
+    // Teardown code here.
+  }
+}
+BENCHMARK(BM_MultiThreaded)->Threads(4);
+
+
+If a benchmark runs a few milliseconds it may be hard to visually compare the
+measured times, since the output data is given in nanoseconds per default.
In
+order to manually set the time unit, you can specify it manually:
+
+BENCHMARK(BM_test)->Unit(benchmark::kMillisecond);
+*/
+
+#ifndef BENCHMARK_BENCHMARK_H_
+#define BENCHMARK_BENCHMARK_H_
+
+// The _MSVC_LANG check should detect Visual Studio 2015 Update 3 and newer.
+#if __cplusplus >= 201103L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201103L)
+#define BENCHMARK_HAS_CXX11
+#endif
+
+// This _MSC_VER check should detect VS 2017 v15.3 and newer.
+#if __cplusplus >= 201703L || \
+    (defined(_MSC_VER) && _MSC_VER >= 1911 && _MSVC_LANG >= 201703L)
+#define BENCHMARK_HAS_CXX17
+#endif
+
+#include <stdint.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <iosfwd>
+#include <limits>
+#include <map>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "benchmark/export.h"
+
+#if defined(BENCHMARK_HAS_CXX11)
+#include <atomic>
+#include <initializer_list>
+#include <type_traits>
+#include <utility>
+#endif
+
+#if defined(_MSC_VER)
+#include <intrin.h>  // for _ReadWriteBarrier
+#endif
+
+#ifndef BENCHMARK_HAS_CXX11
+#define BENCHMARK_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+  TypeName(const TypeName&);                         \
+  TypeName& operator=(const TypeName&)
+#else
+#define BENCHMARK_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+  TypeName(const TypeName&) = delete;                \
+  TypeName& operator=(const TypeName&) = delete
+#endif
+
+#ifdef BENCHMARK_HAS_CXX17
+#define BENCHMARK_UNUSED [[maybe_unused]]
+#elif defined(__GNUC__) || defined(__clang__)
+#define BENCHMARK_UNUSED __attribute__((unused))
+#else
+#define BENCHMARK_UNUSED
+#endif
+
+// Used to annotate functions, methods and classes so they
+// are not optimized by the compiler.
Useful for tests +// where you expect loops to stay in place churning cycles +#if defined(__clang__) +#define BENCHMARK_DONT_OPTIMIZE __attribute__((optnone)) +#elif defined(__GNUC__) || defined(__GNUG__) +#define BENCHMARK_DONT_OPTIMIZE __attribute__((optimize(0))) +#else +// MSVC & Intel do not have a no-optimize attribute, only line pragmas +#define BENCHMARK_DONT_OPTIMIZE +#endif + +#if defined(__GNUC__) || defined(__clang__) +#define BENCHMARK_ALWAYS_INLINE __attribute__((always_inline)) +#elif defined(_MSC_VER) && !defined(__clang__) +#define BENCHMARK_ALWAYS_INLINE __forceinline +#define __func__ __FUNCTION__ +#else +#define BENCHMARK_ALWAYS_INLINE +#endif + +#define BENCHMARK_INTERNAL_TOSTRING2(x) #x +#define BENCHMARK_INTERNAL_TOSTRING(x) BENCHMARK_INTERNAL_TOSTRING2(x) + +// clang-format off +#if (defined(__GNUC__) && !defined(__NVCC__) && !defined(__NVCOMPILER)) || defined(__clang__) +#define BENCHMARK_BUILTIN_EXPECT(x, y) __builtin_expect(x, y) +#define BENCHMARK_DEPRECATED_MSG(msg) __attribute__((deprecated(msg))) +#define BENCHMARK_DISABLE_DEPRECATED_WARNING \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#define BENCHMARK_RESTORE_DEPRECATED_WARNING _Pragma("GCC diagnostic pop") +#elif defined(__NVCOMPILER) +#define BENCHMARK_BUILTIN_EXPECT(x, y) __builtin_expect(x, y) +#define BENCHMARK_DEPRECATED_MSG(msg) __attribute__((deprecated(msg))) +#define BENCHMARK_DISABLE_DEPRECATED_WARNING \ + _Pragma("diagnostic push") \ + _Pragma("diag_suppress deprecated_entity_with_custom_message") +#define BENCHMARK_RESTORE_DEPRECATED_WARNING _Pragma("diagnostic pop") +#else +#define BENCHMARK_BUILTIN_EXPECT(x, y) x +#define BENCHMARK_DEPRECATED_MSG(msg) +#define BENCHMARK_WARNING_MSG(msg) \ + __pragma(message(__FILE__ "(" BENCHMARK_INTERNAL_TOSTRING( \ + __LINE__) ") : warning note: " msg)) +#define BENCHMARK_DISABLE_DEPRECATED_WARNING +#define BENCHMARK_RESTORE_DEPRECATED_WARNING +#endif +// clang-format on + 
+#if defined(__GNUC__) && !defined(__clang__) +#define BENCHMARK_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +#endif + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + +#if defined(__GNUC__) || __has_builtin(__builtin_unreachable) +#define BENCHMARK_UNREACHABLE() __builtin_unreachable() +#elif defined(_MSC_VER) +#define BENCHMARK_UNREACHABLE() __assume(false) +#else +#define BENCHMARK_UNREACHABLE() ((void)0) +#endif + +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK_OVERRIDE override +#else +#define BENCHMARK_OVERRIDE +#endif + +#if defined(__GNUC__) +// Determine the cacheline size based on architecture +#if defined(__i386__) || defined(__x86_64__) +#define BENCHMARK_INTERNAL_CACHELINE_SIZE 64 +#elif defined(__powerpc64__) +#define BENCHMARK_INTERNAL_CACHELINE_SIZE 128 +#elif defined(__aarch64__) +#define BENCHMARK_INTERNAL_CACHELINE_SIZE 64 +#elif defined(__arm__) +// Cache line sizes for ARM: These values are not strictly correct since +// cache line sizes depend on implementations, not architectures. There +// are even implementations with cache line sizes configurable at boot +// time. +#if defined(__ARM_ARCH_5T__) +#define BENCHMARK_INTERNAL_CACHELINE_SIZE 32 +#elif defined(__ARM_ARCH_7A__) +#define BENCHMARK_INTERNAL_CACHELINE_SIZE 64 +#endif // ARM_ARCH +#endif // arches +#endif // __GNUC__ + +#ifndef BENCHMARK_INTERNAL_CACHELINE_SIZE +// A reasonable default guess. Note that overestimates tend to waste more +// space, while underestimates tend to waste more time. +#define BENCHMARK_INTERNAL_CACHELINE_SIZE 64 +#endif + +#if defined(__GNUC__) +// Indicates that the declared object be cache aligned using +// `BENCHMARK_INTERNAL_CACHELINE_SIZE` (see above). 
+#define BENCHMARK_INTERNAL_CACHELINE_ALIGNED \ + __attribute__((aligned(BENCHMARK_INTERNAL_CACHELINE_SIZE))) +#elif defined(_MSC_VER) +#define BENCHMARK_INTERNAL_CACHELINE_ALIGNED \ + __declspec(align(BENCHMARK_INTERNAL_CACHELINE_SIZE)) +#else +#define BENCHMARK_INTERNAL_CACHELINE_ALIGNED +#endif + +#if defined(_MSC_VER) +#pragma warning(push) +// C4251: needs to have dll-interface to be used by clients of class +#pragma warning(disable : 4251) +#endif // _MSC_VER_ + +namespace benchmark { +class BenchmarkReporter; + +// Default number of minimum benchmark running time in seconds. +const char kDefaultMinTimeStr[] = "0.5s"; + +// Returns the version of the library. +BENCHMARK_EXPORT std::string GetBenchmarkVersion(); + +BENCHMARK_EXPORT void PrintDefaultHelp(); + +BENCHMARK_EXPORT void Initialize(int* argc, char** argv, + void (*HelperPrinterf)() = PrintDefaultHelp); +BENCHMARK_EXPORT void Shutdown(); + +// Report to stdout all arguments in 'argv' as unrecognized except the first. +// Returns true there is at least on unrecognized argument (i.e. 'argc' > 1). +BENCHMARK_EXPORT bool ReportUnrecognizedArguments(int argc, char** argv); + +// Returns the current value of --benchmark_filter. +BENCHMARK_EXPORT std::string GetBenchmarkFilter(); + +// Sets a new value to --benchmark_filter. (This will override this flag's +// current value). +// Should be called after `benchmark::Initialize()`, as +// `benchmark::Initialize()` will override the flag's value. +BENCHMARK_EXPORT void SetBenchmarkFilter(std::string value); + +// Returns the current value of --v (command line value for verbosity). +BENCHMARK_EXPORT int32_t GetBenchmarkVerbosity(); + +// Creates a default display reporter. 
Used by the library when no display +// reporter is provided, but also made available for external use in case a +// custom reporter should respect the `--benchmark_format` flag as a fallback +BENCHMARK_EXPORT BenchmarkReporter* CreateDefaultDisplayReporter(); + +// Generate a list of benchmarks matching the specified --benchmark_filter flag +// and if --benchmark_list_tests is specified return after printing the name +// of each matching benchmark. Otherwise run each matching benchmark and +// report the results. +// +// spec : Specify the benchmarks to run. If users do not specify this arg, +// then the value of FLAGS_benchmark_filter +// will be used. +// +// The second and third overload use the specified 'display_reporter' and +// 'file_reporter' respectively. 'file_reporter' will write to the file +// specified +// by '--benchmark_out'. If '--benchmark_out' is not given the +// 'file_reporter' is ignored. +// +// RETURNS: The number of matching benchmarks. +BENCHMARK_EXPORT size_t RunSpecifiedBenchmarks(); +BENCHMARK_EXPORT size_t RunSpecifiedBenchmarks(std::string spec); + +BENCHMARK_EXPORT size_t +RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter); +BENCHMARK_EXPORT size_t +RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter, std::string spec); + +BENCHMARK_EXPORT size_t RunSpecifiedBenchmarks( + BenchmarkReporter* display_reporter, BenchmarkReporter* file_reporter); +BENCHMARK_EXPORT size_t +RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter, + BenchmarkReporter* file_reporter, std::string spec); + +// TimeUnit is passed to a benchmark in order to specify the order of magnitude +// for the measured time. 
+enum TimeUnit { kNanosecond, kMicrosecond, kMillisecond, kSecond }; + +BENCHMARK_EXPORT TimeUnit GetDefaultTimeUnit(); + +// Sets the default time unit the benchmarks use +// Has to be called before the benchmark loop to take effect +BENCHMARK_EXPORT void SetDefaultTimeUnit(TimeUnit unit); + +// If a MemoryManager is registered (via RegisterMemoryManager()), +// it can be used to collect and report allocation metrics for a run of the +// benchmark. +class MemoryManager { + public: + static const int64_t TombstoneValue; + + struct Result { + Result() + : num_allocs(0), + max_bytes_used(0), + total_allocated_bytes(TombstoneValue), + net_heap_growth(TombstoneValue) {} + + // The number of allocations made in total between Start and Stop. + int64_t num_allocs; + + // The peak memory use between Start and Stop. + int64_t max_bytes_used; + + // The total memory allocated, in bytes, between Start and Stop. + // Init'ed to TombstoneValue if metric not available. + int64_t total_allocated_bytes; + + // The net changes in memory, in bytes, between Start and Stop. + // ie., total_allocated_bytes - total_deallocated_bytes. + // Init'ed to TombstoneValue if metric not available. + int64_t net_heap_growth; + }; + + virtual ~MemoryManager() {} + + // Implement this to start recording allocation information. + virtual void Start() = 0; + + // Implement this to stop recording and fill out the given Result structure. + virtual void Stop(Result& result) = 0; +}; + +// Register a MemoryManager instance that will be used to collect and report +// allocation measurements for benchmark runs. +BENCHMARK_EXPORT +void RegisterMemoryManager(MemoryManager* memory_manager); + +// If a ProfilerManager is registered (via RegisterProfilerManager()), the +// benchmark will be run an additional time under the profiler to collect and +// report profile metrics for the run of the benchmark. 
+class ProfilerManager {
+ public:
+  virtual ~ProfilerManager() {}
+
+  // This is called after `Setup()` code and right before the benchmark is run.
+  virtual void AfterSetupStart() = 0;
+
+  // This is called before `Teardown()` code and right after the benchmark
+  // completes.
+  virtual void BeforeTeardownStop() = 0;
+};
+
+// Register a ProfilerManager instance that will be used to collect and report
+// profile measurements for benchmark runs.
+BENCHMARK_EXPORT
+void RegisterProfilerManager(ProfilerManager* profiler_manager);
+
+// Add a key-value pair to output as part of the context stanza in the report.
+BENCHMARK_EXPORT
+void AddCustomContext(const std::string& key, const std::string& value);
+
+namespace internal {
+class Benchmark;
+class BenchmarkImp;
+class BenchmarkFamilies;
+
+BENCHMARK_EXPORT std::map<std::string, std::string>*& GetGlobalContext();
+
+BENCHMARK_EXPORT
+void UseCharPointer(char const volatile*);
+
+// Take ownership of the pointer and register the benchmark. Return the
+// registered benchmark.
+BENCHMARK_EXPORT Benchmark* RegisterBenchmarkInternal(Benchmark*);
+
+// Ensure that the standard streams are properly initialized in every TU.
+BENCHMARK_EXPORT int InitializeStreams();
+BENCHMARK_UNUSED static int stream_init_anchor = InitializeStreams();
+
+}  // namespace internal
+
+#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__pnacl__) || \
+    defined(__EMSCRIPTEN__)
+#define BENCHMARK_HAS_NO_INLINE_ASSEMBLY
+#endif
+
+// Force the compiler to flush pending writes to global memory. Acts as an
+// effective read/write barrier
+#ifdef BENCHMARK_HAS_CXX11
+inline BENCHMARK_ALWAYS_INLINE void ClobberMemory() {
+  std::atomic_signal_fence(std::memory_order_acq_rel);
+}
+#endif
+
+// The DoNotOptimize(...) function can be used to prevent a value or
+// expression from being optimized away by the compiler. This function is
+// intended to add little to no overhead.
+// See: https://youtu.be/nXaxk27zwlk?t=2441
+#ifndef BENCHMARK_HAS_NO_INLINE_ASSEMBLY
+#if !defined(__GNUC__) || defined(__llvm__) || defined(__INTEL_COMPILER)
+template <class Tp>
+BENCHMARK_DEPRECATED_MSG(
+    "The const-ref version of this method can permit "
+    "undesired compiler optimizations in benchmarks")
+inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp const& value) {
+  asm volatile("" : : "r,m"(value) : "memory");
+}
+
+template <class Tp>
+inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp& value) {
+#if defined(__clang__)
+  asm volatile("" : "+r,m"(value) : : "memory");
+#else
+  asm volatile("" : "+m,r"(value) : : "memory");
+#endif
+}
+
+#ifdef BENCHMARK_HAS_CXX11
+template <class Tp>
+inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp&& value) {
+#if defined(__clang__)
+  asm volatile("" : "+r,m"(value) : : "memory");
+#else
+  asm volatile("" : "+m,r"(value) : : "memory");
+#endif
+}
+#endif
+#elif defined(BENCHMARK_HAS_CXX11) && (__GNUC__ >= 5)
+// Workaround for a bug with full argument copy overhead with GCC.
+// See: #1340 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105519 +template +BENCHMARK_DEPRECATED_MSG( + "The const-ref version of this method can permit " + "undesired compiler optimizations in benchmarks") +inline BENCHMARK_ALWAYS_INLINE + typename std::enable_if::value && + (sizeof(Tp) <= sizeof(Tp*))>::type + DoNotOptimize(Tp const& value) { + asm volatile("" : : "r,m"(value) : "memory"); +} + +template +BENCHMARK_DEPRECATED_MSG( + "The const-ref version of this method can permit " + "undesired compiler optimizations in benchmarks") +inline BENCHMARK_ALWAYS_INLINE + typename std::enable_if::value || + (sizeof(Tp) > sizeof(Tp*))>::type + DoNotOptimize(Tp const& value) { + asm volatile("" : : "m"(value) : "memory"); +} + +template +inline BENCHMARK_ALWAYS_INLINE + typename std::enable_if::value && + (sizeof(Tp) <= sizeof(Tp*))>::type + DoNotOptimize(Tp& value) { + asm volatile("" : "+m,r"(value) : : "memory"); +} + +template +inline BENCHMARK_ALWAYS_INLINE + typename std::enable_if::value || + (sizeof(Tp) > sizeof(Tp*))>::type + DoNotOptimize(Tp& value) { + asm volatile("" : "+m"(value) : : "memory"); +} + +template +inline BENCHMARK_ALWAYS_INLINE + typename std::enable_if::value && + (sizeof(Tp) <= sizeof(Tp*))>::type + DoNotOptimize(Tp&& value) { + asm volatile("" : "+m,r"(value) : : "memory"); +} + +template +inline BENCHMARK_ALWAYS_INLINE + typename std::enable_if::value || + (sizeof(Tp) > sizeof(Tp*))>::type + DoNotOptimize(Tp&& value) { + asm volatile("" : "+m"(value) : : "memory"); +} + +#else +// Fallback for GCC < 5. Can add some overhead because the compiler is forced +// to use memory operations instead of operations with registers. +// TODO: Remove if GCC < 5 will be unsupported. 
+template +BENCHMARK_DEPRECATED_MSG( + "The const-ref version of this method can permit " + "undesired compiler optimizations in benchmarks") +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp const& value) { + asm volatile("" : : "m"(value) : "memory"); +} + +template +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp& value) { + asm volatile("" : "+m"(value) : : "memory"); +} + +#ifdef BENCHMARK_HAS_CXX11 +template +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp&& value) { + asm volatile("" : "+m"(value) : : "memory"); +} +#endif +#endif + +#ifndef BENCHMARK_HAS_CXX11 +inline BENCHMARK_ALWAYS_INLINE void ClobberMemory() { + asm volatile("" : : : "memory"); +} +#endif +#elif defined(_MSC_VER) +template +BENCHMARK_DEPRECATED_MSG( + "The const-ref version of this method can permit " + "undesired compiler optimizations in benchmarks") +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp const& value) { + internal::UseCharPointer(&reinterpret_cast(value)); + _ReadWriteBarrier(); +} + +#ifndef BENCHMARK_HAS_CXX11 +inline BENCHMARK_ALWAYS_INLINE void ClobberMemory() { _ReadWriteBarrier(); } +#endif +#else +#ifdef BENCHMARK_HAS_CXX11 +template +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp&& value) { + internal::UseCharPointer(&reinterpret_cast(value)); +} +#else +template +BENCHMARK_DEPRECATED_MSG( + "The const-ref version of this method can permit " + "undesired compiler optimizations in benchmarks") +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp const& value) { + internal::UseCharPointer(&reinterpret_cast(value)); +} + +template +inline BENCHMARK_ALWAYS_INLINE void DoNotOptimize(Tp& value) { + internal::UseCharPointer(&reinterpret_cast(value)); +} +#endif +// FIXME Add ClobberMemory() for non-gnu and non-msvc compilers, before C++11. +#endif + +// This class is used for user-defined counters. +class Counter { + public: + enum Flags { + kDefaults = 0, + // Mark the counter as a rate. 
It will be presented divided + // by the duration of the benchmark. + kIsRate = 1 << 0, + // Mark the counter as a thread-average quantity. It will be + // presented divided by the number of threads. + kAvgThreads = 1 << 1, + // Mark the counter as a thread-average rate. See above. + kAvgThreadsRate = kIsRate | kAvgThreads, + // Mark the counter as a constant value, valid/same for *every* iteration. + // When reporting, it will be *multiplied* by the iteration count. + kIsIterationInvariant = 1 << 2, + // Mark the counter as a constant rate. + // When reporting, it will be *multiplied* by the iteration count + // and then divided by the duration of the benchmark. + kIsIterationInvariantRate = kIsRate | kIsIterationInvariant, + // Mark the counter as a iteration-average quantity. + // It will be presented divided by the number of iterations. + kAvgIterations = 1 << 3, + // Mark the counter as a iteration-average rate. See above. + kAvgIterationsRate = kIsRate | kAvgIterations, + + // In the end, invert the result. This is always done last! + kInvert = 1 << 31 + }; + + enum OneK { + // 1'000 items per 1k + kIs1000 = 1000, + // 1'024 items per 1k + kIs1024 = 1024 + }; + + double value; + Flags flags; + OneK oneK; + + BENCHMARK_ALWAYS_INLINE + Counter(double v = 0., Flags f = kDefaults, OneK k = kIs1000) + : value(v), flags(f), oneK(k) {} + + BENCHMARK_ALWAYS_INLINE operator double const &() const { return value; } + BENCHMARK_ALWAYS_INLINE operator double&() { return value; } +}; + +// A helper for user code to create unforeseen combinations of Flags, without +// having to do this cast manually each time, or providing this operator. +Counter::Flags inline operator|(const Counter::Flags& LHS, + const Counter::Flags& RHS) { + return static_cast(static_cast(LHS) | + static_cast(RHS)); +} + +// This is the container for the user-defined counters. 
+typedef std::map UserCounters; + +// BigO is passed to a benchmark in order to specify the asymptotic +// computational +// complexity for the benchmark. In case oAuto is selected, complexity will be +// calculated automatically to the best fit. +enum BigO { oNone, o1, oN, oNSquared, oNCubed, oLogN, oNLogN, oAuto, oLambda }; + +typedef int64_t ComplexityN; + +typedef int64_t IterationCount; + +enum StatisticUnit { kTime, kPercentage }; + +// BigOFunc is passed to a benchmark in order to specify the asymptotic +// computational complexity for the benchmark. +typedef double(BigOFunc)(ComplexityN); + +// StatisticsFunc is passed to a benchmark in order to compute some descriptive +// statistics over all the measurements of some type +typedef double(StatisticsFunc)(const std::vector&); + +namespace internal { +struct Statistics { + std::string name_; + StatisticsFunc* compute_; + StatisticUnit unit_; + + Statistics(const std::string& name, StatisticsFunc* compute, + StatisticUnit unit = kTime) + : name_(name), compute_(compute), unit_(unit) {} +}; + +class BenchmarkInstance; +class ThreadTimer; +class ThreadManager; +class PerfCountersMeasurement; + +enum AggregationReportMode +#if defined(BENCHMARK_HAS_CXX11) + : unsigned +#else +#endif +{ + // The mode has not been manually specified + ARM_Unspecified = 0, + // The mode is user-specified. + // This may or may not be set when the following bit-flags are set. + ARM_Default = 1U << 0U, + // File reporter should only output aggregates. + ARM_FileReportAggregatesOnly = 1U << 1U, + // Display reporter should only output aggregates + ARM_DisplayReportAggregatesOnly = 1U << 2U, + // Both reporters should only display aggregates. 
+ ARM_ReportAggregatesOnly = + ARM_FileReportAggregatesOnly | ARM_DisplayReportAggregatesOnly +}; + +enum Skipped +#if defined(BENCHMARK_HAS_CXX11) + : unsigned +#endif +{ + NotSkipped = 0, + SkippedWithMessage, + SkippedWithError +}; + +} // namespace internal + +#if defined(_MSC_VER) +#pragma warning(push) +// C4324: 'benchmark::State': structure was padded due to alignment specifier +#pragma warning(disable : 4324) +#endif // _MSC_VER_ +// State is passed to a running Benchmark and contains state for the +// benchmark to use. +class BENCHMARK_EXPORT BENCHMARK_INTERNAL_CACHELINE_ALIGNED State { + public: + struct StateIterator; + friend struct StateIterator; + + // Returns iterators used to run each iteration of a benchmark using a + // C++11 ranged-based for loop. These functions should not be called directly. + // + // REQUIRES: The benchmark has not started running yet. Neither begin nor end + // have been called previously. + // + // NOTE: KeepRunning may not be used after calling either of these functions. + inline BENCHMARK_ALWAYS_INLINE StateIterator begin(); + inline BENCHMARK_ALWAYS_INLINE StateIterator end(); + + // Returns true if the benchmark should continue through another iteration. + // NOTE: A benchmark may not return from the test until KeepRunning() has + // returned false. + inline bool KeepRunning(); + + // Returns true iff the benchmark should run n more iterations. + // REQUIRES: 'n' > 0. + // NOTE: A benchmark must not return from the test until KeepRunningBatch() + // has returned false. + // NOTE: KeepRunningBatch() may overshoot by up to 'n' iterations. + // + // Intended usage: + // while (state.KeepRunningBatch(1000)) { + // // process 1000 elements + // } + inline bool KeepRunningBatch(IterationCount n); + + // REQUIRES: timer is running and 'SkipWithMessage(...)' or + // 'SkipWithError(...)' has not been called by the current thread. + // Stop the benchmark timer. 
If not called, the timer will be + // automatically stopped after the last iteration of the benchmark loop. + // + // For threaded benchmarks the PauseTiming() function only pauses the timing + // for the current thread. + // + // NOTE: The "real time" measurement is per-thread. If different threads + // report different measurements the largest one is reported. + // + // NOTE: PauseTiming()/ResumeTiming() are relatively + // heavyweight, and so their use should generally be avoided + // within each benchmark iteration, if possible. + void PauseTiming(); + + // REQUIRES: timer is not running and 'SkipWithMessage(...)' or + // 'SkipWithError(...)' has not been called by the current thread. + // Start the benchmark timer. The timer is NOT running on entrance to the + // benchmark function. It begins running after control flow enters the + // benchmark loop. + // + // NOTE: PauseTiming()/ResumeTiming() are relatively + // heavyweight, and so their use should generally be avoided + // within each benchmark iteration, if possible. + void ResumeTiming(); + + // REQUIRES: 'SkipWithMessage(...)' or 'SkipWithError(...)' has not been + // called previously by the current thread. + // Report the benchmark as resulting in being skipped with the specified + // 'msg'. + // After this call the user may explicitly 'return' from the benchmark. + // + // If the ranged-for style of benchmark loop is used, the user must explicitly + // break from the loop, otherwise all future iterations will be run. + // If the 'KeepRunning()' loop is used the current thread will automatically + // exit the loop at the end of the current iteration. + // + // For threaded benchmarks only the current thread stops executing and future + // calls to `KeepRunning()` will block until all threads have completed + // the `KeepRunning()` loop. If multiple threads report being skipped only the + // first skip message is used. 
+ // + // NOTE: Calling 'SkipWithMessage(...)' does not cause the benchmark to exit + // the current scope immediately. If the function is called from within + // the 'KeepRunning()' loop the current iteration will finish. It is the users + // responsibility to exit the scope as needed. + void SkipWithMessage(const std::string& msg); + + // REQUIRES: 'SkipWithMessage(...)' or 'SkipWithError(...)' has not been + // called previously by the current thread. + // Report the benchmark as resulting in an error with the specified 'msg'. + // After this call the user may explicitly 'return' from the benchmark. + // + // If the ranged-for style of benchmark loop is used, the user must explicitly + // break from the loop, otherwise all future iterations will be run. + // If the 'KeepRunning()' loop is used the current thread will automatically + // exit the loop at the end of the current iteration. + // + // For threaded benchmarks only the current thread stops executing and future + // calls to `KeepRunning()` will block until all threads have completed + // the `KeepRunning()` loop. If multiple threads report an error only the + // first error message is used. + // + // NOTE: Calling 'SkipWithError(...)' does not cause the benchmark to exit + // the current scope immediately. If the function is called from within + // the 'KeepRunning()' loop the current iteration will finish. It is the users + // responsibility to exit the scope as needed. + void SkipWithError(const std::string& msg); + + // Returns true if 'SkipWithMessage(...)' or 'SkipWithError(...)' was called. + bool skipped() const { return internal::NotSkipped != skipped_; } + + // Returns true if an error has been reported with 'SkipWithError(...)'. + bool error_occurred() const { return internal::SkippedWithError == skipped_; } + + // REQUIRES: called exactly once per iteration of the benchmarking loop. 
+ // Set the manually measured time for this benchmark iteration, which + // is used instead of automatically measured time if UseManualTime() was + // specified. + // + // For threaded benchmarks the final value will be set to the largest + // reported values. + void SetIterationTime(double seconds); + + // Set the number of bytes processed by the current benchmark + // execution. This routine is typically called once at the end of a + // throughput oriented benchmark. + // + // REQUIRES: a benchmark has exited its benchmarking loop. + BENCHMARK_ALWAYS_INLINE + void SetBytesProcessed(int64_t bytes) { + counters["bytes_per_second"] = + Counter(static_cast(bytes), Counter::kIsRate, Counter::kIs1024); + } + + BENCHMARK_ALWAYS_INLINE + int64_t bytes_processed() const { + if (counters.find("bytes_per_second") != counters.end()) + return static_cast(counters.at("bytes_per_second")); + return 0; + } + + // If this routine is called with complexity_n > 0 and complexity report is + // requested for the + // family benchmark, then current benchmark will be part of the computation + // and complexity_n will + // represent the length of N. + BENCHMARK_ALWAYS_INLINE + void SetComplexityN(ComplexityN complexity_n) { + complexity_n_ = complexity_n; + } + + BENCHMARK_ALWAYS_INLINE + ComplexityN complexity_length_n() const { return complexity_n_; } + + // If this routine is called with items > 0, then an items/s + // label is printed on the benchmark report line for the currently + // executing benchmark. It is typically called at the end of a processing + // benchmark where a processing items/second output is desired. + // + // REQUIRES: a benchmark has exited its benchmarking loop. 
+ BENCHMARK_ALWAYS_INLINE + void SetItemsProcessed(int64_t items) { + counters["items_per_second"] = + Counter(static_cast(items), benchmark::Counter::kIsRate); + } + + BENCHMARK_ALWAYS_INLINE + int64_t items_processed() const { + if (counters.find("items_per_second") != counters.end()) + return static_cast(counters.at("items_per_second")); + return 0; + } + + // If this routine is called, the specified label is printed at the + // end of the benchmark report line for the currently executing + // benchmark. Example: + // static void BM_Compress(benchmark::State& state) { + // ... + // double compress = input_size / output_size; + // state.SetLabel(StrFormat("compress:%.1f%%", 100.0*compression)); + // } + // Produces output that looks like: + // BM_Compress 50 50 14115038 compress:27.3% + // + // REQUIRES: a benchmark has exited its benchmarking loop. + void SetLabel(const std::string& label); + + // Range arguments for this run. CHECKs if the argument has been set. + BENCHMARK_ALWAYS_INLINE + int64_t range(std::size_t pos = 0) const { + assert(range_.size() > pos); + return range_[pos]; + } + + BENCHMARK_DEPRECATED_MSG("use 'range(0)' instead") + int64_t range_x() const { return range(0); } + + BENCHMARK_DEPRECATED_MSG("use 'range(1)' instead") + int64_t range_y() const { return range(1); } + + // Number of threads concurrently executing the benchmark. + BENCHMARK_ALWAYS_INLINE + int threads() const { return threads_; } + + // Index of the executing thread. Values from [0, threads). 
+ BENCHMARK_ALWAYS_INLINE + int thread_index() const { return thread_index_; } + + BENCHMARK_ALWAYS_INLINE + IterationCount iterations() const { + if (BENCHMARK_BUILTIN_EXPECT(!started_, false)) { + return 0; + } + return max_iterations - total_iterations_ + batch_leftover_; + } + + BENCHMARK_ALWAYS_INLINE + std::string name() const { return name_; } + + private: + // items we expect on the first cache line (ie 64 bytes of the struct) + // When total_iterations_ is 0, KeepRunning() and friends will return false. + // May be larger than max_iterations. + IterationCount total_iterations_; + + // When using KeepRunningBatch(), batch_leftover_ holds the number of + // iterations beyond max_iters that were run. Used to track + // completed_iterations_ accurately. + IterationCount batch_leftover_; + + public: + const IterationCount max_iterations; + + private: + bool started_; + bool finished_; + internal::Skipped skipped_; + + // items we don't need on the first cache line + std::vector range_; + + ComplexityN complexity_n_; + + public: + // Container for user-defined counters. + UserCounters counters; + + private: + State(std::string name, IterationCount max_iters, + const std::vector& ranges, int thread_i, int n_threads, + internal::ThreadTimer* timer, internal::ThreadManager* manager, + internal::PerfCountersMeasurement* perf_counters_measurement, + ProfilerManager* profiler_manager); + + void StartKeepRunning(); + // Implementation of KeepRunning() and KeepRunningBatch(). + // is_batch must be true unless n is 1. 
+ inline bool KeepRunningInternal(IterationCount n, bool is_batch); + void FinishKeepRunning(); + + const std::string name_; + const int thread_index_; + const int threads_; + + internal::ThreadTimer* const timer_; + internal::ThreadManager* const manager_; + internal::PerfCountersMeasurement* const perf_counters_measurement_; + ProfilerManager* const profiler_manager_; + + friend class internal::BenchmarkInstance; +}; +#if defined(_MSC_VER) +#pragma warning(pop) +#endif // _MSC_VER_ + +inline BENCHMARK_ALWAYS_INLINE bool State::KeepRunning() { + return KeepRunningInternal(1, /*is_batch=*/false); +} + +inline BENCHMARK_ALWAYS_INLINE bool State::KeepRunningBatch(IterationCount n) { + return KeepRunningInternal(n, /*is_batch=*/true); +} + +inline BENCHMARK_ALWAYS_INLINE bool State::KeepRunningInternal(IterationCount n, + bool is_batch) { + // total_iterations_ is set to 0 by the constructor, and always set to a + // nonzero value by StartKepRunning(). + assert(n > 0); + // n must be 1 unless is_batch is true. + assert(is_batch || n == 1); + if (BENCHMARK_BUILTIN_EXPECT(total_iterations_ >= n, true)) { + total_iterations_ -= n; + return true; + } + if (!started_) { + StartKeepRunning(); + if (!skipped() && total_iterations_ >= n) { + total_iterations_ -= n; + return true; + } + } + // For non-batch runs, total_iterations_ must be 0 by now. + if (is_batch && total_iterations_ != 0) { + batch_leftover_ = n - total_iterations_; + total_iterations_ = 0; + return true; + } + FinishKeepRunning(); + return false; +} + +struct State::StateIterator { + struct BENCHMARK_UNUSED Value {}; + typedef std::forward_iterator_tag iterator_category; + typedef Value value_type; + typedef Value reference; + typedef Value pointer; + typedef std::ptrdiff_t difference_type; + + private: + friend class State; + BENCHMARK_ALWAYS_INLINE + StateIterator() : cached_(0), parent_() {} + + BENCHMARK_ALWAYS_INLINE + explicit StateIterator(State* st) + : cached_(st->skipped() ? 
0 : st->max_iterations), parent_(st) {} + + public: + BENCHMARK_ALWAYS_INLINE + Value operator*() const { return Value(); } + + BENCHMARK_ALWAYS_INLINE + StateIterator& operator++() { + assert(cached_ > 0); + --cached_; + return *this; + } + + BENCHMARK_ALWAYS_INLINE + bool operator!=(StateIterator const&) const { + if (BENCHMARK_BUILTIN_EXPECT(cached_ != 0, true)) return true; + parent_->FinishKeepRunning(); + return false; + } + + private: + IterationCount cached_; + State* const parent_; +}; + +inline BENCHMARK_ALWAYS_INLINE State::StateIterator State::begin() { + return StateIterator(this); +} +inline BENCHMARK_ALWAYS_INLINE State::StateIterator State::end() { + StartKeepRunning(); + return StateIterator(); +} + +namespace internal { + +typedef void(Function)(State&); + +// ------------------------------------------------------ +// Benchmark registration object. The BENCHMARK() macro expands +// into an internal::Benchmark* object. Various methods can +// be called on this object to change the properties of the benchmark. +// Each method returns "this" so that multiple method calls can +// chained into one expression. +class BENCHMARK_EXPORT Benchmark { + public: + virtual ~Benchmark(); + + // Note: the following methods all return "this" so that multiple + // method calls can be chained together in one expression. + + // Specify the name of the benchmark + Benchmark* Name(const std::string& name); + + // Run this benchmark once with "x" as the extra argument passed + // to the function. + // REQUIRES: The function passed to the constructor must accept an arg1. + Benchmark* Arg(int64_t x); + + // Run this benchmark with the given time unit for the generated output report + Benchmark* Unit(TimeUnit unit); + + // Run this benchmark once for a number of values picked from the + // range [start..limit]. (start and limit are always picked.) + // REQUIRES: The function passed to the constructor must accept an arg1. 
+ Benchmark* Range(int64_t start, int64_t limit); + + // Run this benchmark once for all values in the range [start..limit] with + // specific step + // REQUIRES: The function passed to the constructor must accept an arg1. + Benchmark* DenseRange(int64_t start, int64_t limit, int step = 1); + + // Run this benchmark once with "args" as the extra arguments passed + // to the function. + // REQUIRES: The function passed to the constructor must accept arg1, arg2 ... + Benchmark* Args(const std::vector& args); + + // Equivalent to Args({x, y}) + // NOTE: This is a legacy C++03 interface provided for compatibility only. + // New code should use 'Args'. + Benchmark* ArgPair(int64_t x, int64_t y) { + std::vector args; + args.push_back(x); + args.push_back(y); + return Args(args); + } + + // Run this benchmark once for a number of values picked from the + // ranges [start..limit]. (starts and limits are always picked.) + // REQUIRES: The function passed to the constructor must accept arg1, arg2 ... + Benchmark* Ranges(const std::vector >& ranges); + + // Run this benchmark once for each combination of values in the (cartesian) + // product of the supplied argument lists. + // REQUIRES: The function passed to the constructor must accept arg1, arg2 ... + Benchmark* ArgsProduct(const std::vector >& arglists); + + // Equivalent to ArgNames({name}) + Benchmark* ArgName(const std::string& name); + + // Set the argument names to display in the benchmark name. If not called, + // only argument values will be shown. + Benchmark* ArgNames(const std::vector& names); + + // Equivalent to Ranges({{lo1, hi1}, {lo2, hi2}}). + // NOTE: This is a legacy C++03 interface provided for compatibility only. + // New code should use 'Ranges'. 
+ Benchmark* RangePair(int64_t lo1, int64_t hi1, int64_t lo2, int64_t hi2) { + std::vector > ranges; + ranges.push_back(std::make_pair(lo1, hi1)); + ranges.push_back(std::make_pair(lo2, hi2)); + return Ranges(ranges); + } + + // Have "setup" and/or "teardown" invoked once for every benchmark run. + // If the benchmark is multi-threaded (will run in k threads concurrently), + // the setup callback will be be invoked exactly once (not k times) before + // each run with k threads. Time allowing (e.g. for a short benchmark), there + // may be multiple such runs per benchmark, each run with its own + // "setup"/"teardown". + // + // If the benchmark uses different size groups of threads (e.g. via + // ThreadRange), the above will be true for each size group. + // + // The callback will be passed a State object, which includes the number + // of threads, thread-index, benchmark arguments, etc. + // + // The callback must not be NULL or self-deleting. + Benchmark* Setup(void (*setup)(const benchmark::State&)); + Benchmark* Teardown(void (*teardown)(const benchmark::State&)); + + // Pass this benchmark object to *func, which can customize + // the benchmark by calling various methods like Arg, Args, + // Threads, etc. + Benchmark* Apply(void (*func)(Benchmark* benchmark)); + + // Set the range multiplier for non-dense range. If not called, the range + // multiplier kRangeMultiplier will be used. + Benchmark* RangeMultiplier(int multiplier); + + // Set the minimum amount of time to use when running this benchmark. This + // option overrides the `benchmark_min_time` flag. + // REQUIRES: `t > 0` and `Iterations` has not been called on this benchmark. + Benchmark* MinTime(double t); + + // Set the minimum amount of time to run the benchmark before taking runtimes + // of this benchmark into account. This + // option overrides the `benchmark_min_warmup_time` flag. + // REQUIRES: `t >= 0` and `Iterations` has not been called on this benchmark. 
+ Benchmark* MinWarmUpTime(double t); + + // Specify the amount of iterations that should be run by this benchmark. + // This option overrides the `benchmark_min_time` flag. + // REQUIRES: 'n > 0' and `MinTime` has not been called on this benchmark. + // + // NOTE: This function should only be used when *exact* iteration control is + // needed and never to control or limit how long a benchmark runs, where + // `--benchmark_min_time=s` or `MinTime(...)` should be used instead. + Benchmark* Iterations(IterationCount n); + + // Specify the amount of times to repeat this benchmark. This option overrides + // the `benchmark_repetitions` flag. + // REQUIRES: `n > 0` + Benchmark* Repetitions(int n); + + // Specify if each repetition of the benchmark should be reported separately + // or if only the final statistics should be reported. If the benchmark + // is not repeated then the single result is always reported. + // Applies to *ALL* reporters (display and file). + Benchmark* ReportAggregatesOnly(bool value = true); + + // Same as ReportAggregatesOnly(), but applies to display reporter only. + Benchmark* DisplayAggregatesOnly(bool value = true); + + // By default, the CPU time is measured only for the main thread, which may + // be unrepresentative if the benchmark uses threads internally. If called, + // the total CPU time spent by all the threads will be measured instead. + // By default, only the main thread CPU time will be measured. + Benchmark* MeasureProcessCPUTime(); + + // If a particular benchmark should use the Wall clock instead of the CPU time + // (be it either the CPU time of the main thread only (default), or the + // total CPU usage of the benchmark), call this method. If called, the elapsed + // (wall) time will be used to control how many iterations are run, and in the + // printing of items/second or MB/seconds values. + // If not called, the CPU time used by the benchmark will be used. 
+ Benchmark* UseRealTime(); + + // If a benchmark must measure time manually (e.g. if GPU execution time is + // being + // measured), call this method. If called, each benchmark iteration should + // call + // SetIterationTime(seconds) to report the measured time, which will be used + // to control how many iterations are run, and in the printing of items/second + // or MB/second values. + Benchmark* UseManualTime(); + + // Set the asymptotic computational complexity for the benchmark. If called + // the asymptotic computational complexity will be shown on the output. + Benchmark* Complexity(BigO complexity = benchmark::oAuto); + + // Set the asymptotic computational complexity for the benchmark. If called + // the asymptotic computational complexity will be shown on the output. + Benchmark* Complexity(BigOFunc* complexity); + + // Add this statistics to be computed over all the values of benchmark run + Benchmark* ComputeStatistics(const std::string& name, + StatisticsFunc* statistics, + StatisticUnit unit = kTime); + + // Support for running multiple copies of the same benchmark concurrently + // in multiple threads. This may be useful when measuring the scaling + // of some piece of code. + + // Run one instance of this benchmark concurrently in t threads. + Benchmark* Threads(int t); + + // Pick a set of values T from [min_threads,max_threads]. + // min_threads and max_threads are always included in T. Run this + // benchmark once for each value in T. The benchmark run for a + // particular value t consists of t threads running the benchmark + // function concurrently. For example, consider: + // BENCHMARK(Foo)->ThreadRange(1,16); + // This will run the following benchmarks: + // Foo in 1 thread + // Foo in 2 threads + // Foo in 4 threads + // Foo in 8 threads + // Foo in 16 threads + Benchmark* ThreadRange(int min_threads, int max_threads); + + // For each value n in the range, run this benchmark once using n threads. 
+ // min_threads and max_threads are always included in the range. + // stride specifies the increment. E.g. DenseThreadRange(1, 8, 3) starts + // a benchmark with 1, 4, 7 and 8 threads. + Benchmark* DenseThreadRange(int min_threads, int max_threads, int stride = 1); + + // Equivalent to ThreadRange(NumCPUs(), NumCPUs()) + Benchmark* ThreadPerCpu(); + + virtual void Run(State& state) = 0; + + TimeUnit GetTimeUnit() const; + + protected: + explicit Benchmark(const std::string& name); + void SetName(const std::string& name); + + public: + const char* GetName() const; + int ArgsCnt() const; + const char* GetArgName(int arg) const; + + private: + friend class BenchmarkFamilies; + friend class BenchmarkInstance; + + std::string name_; + AggregationReportMode aggregation_report_mode_; + std::vector arg_names_; // Args for all benchmark runs + std::vector > args_; // Args for all benchmark runs + + TimeUnit time_unit_; + bool use_default_time_unit_; + + int range_multiplier_; + double min_time_; + double min_warmup_time_; + IterationCount iterations_; + int repetitions_; + bool measure_process_cpu_time_; + bool use_real_time_; + bool use_manual_time_; + BigO complexity_; + BigOFunc* complexity_lambda_; + std::vector statistics_; + std::vector thread_counts_; + + typedef void (*callback_function)(const benchmark::State&); + callback_function setup_; + callback_function teardown_; + + Benchmark(Benchmark const&) +#if defined(BENCHMARK_HAS_CXX11) + = delete +#endif + ; + + Benchmark& operator=(Benchmark const&) +#if defined(BENCHMARK_HAS_CXX11) + = delete +#endif + ; +}; + +} // namespace internal + +// Create and register a benchmark with the specified 'name' that invokes +// the specified functor 'fn'. +// +// RETURNS: A pointer to the registered benchmark. 
+internal::Benchmark* RegisterBenchmark(const std::string& name, + internal::Function* fn); + +#if defined(BENCHMARK_HAS_CXX11) +template +internal::Benchmark* RegisterBenchmark(const std::string& name, Lambda&& fn); +#endif + +// Remove all registered benchmarks. All pointers to previously registered +// benchmarks are invalidated. +BENCHMARK_EXPORT void ClearRegisteredBenchmarks(); + +namespace internal { +// The class used to hold all Benchmarks created from static function. +// (ie those created using the BENCHMARK(...) macros. +class BENCHMARK_EXPORT FunctionBenchmark : public Benchmark { + public: + FunctionBenchmark(const std::string& name, Function* func) + : Benchmark(name), func_(func) {} + + void Run(State& st) BENCHMARK_OVERRIDE; + + private: + Function* func_; +}; + +#ifdef BENCHMARK_HAS_CXX11 +template +class LambdaBenchmark : public Benchmark { + public: + void Run(State& st) BENCHMARK_OVERRIDE { lambda_(st); } + + private: + template + LambdaBenchmark(const std::string& name, OLambda&& lam) + : Benchmark(name), lambda_(std::forward(lam)) {} + + LambdaBenchmark(LambdaBenchmark const&) = delete; + + template // NOLINTNEXTLINE(readability-redundant-declaration) + friend Benchmark* ::benchmark::RegisterBenchmark(const std::string&, Lam&&); + + Lambda lambda_; +}; +#endif +} // namespace internal + +inline internal::Benchmark* RegisterBenchmark(const std::string& name, + internal::Function* fn) { + // FIXME: this should be a `std::make_unique<>()` but we don't have C++14. + // codechecker_intentional [cplusplus.NewDeleteLeaks] + return internal::RegisterBenchmarkInternal( + ::new internal::FunctionBenchmark(name, fn)); +} + +#ifdef BENCHMARK_HAS_CXX11 +template +internal::Benchmark* RegisterBenchmark(const std::string& name, Lambda&& fn) { + using BenchType = + internal::LambdaBenchmark::type>; + // FIXME: this should be a `std::make_unique<>()` but we don't have C++14. 
+ // codechecker_intentional [cplusplus.NewDeleteLeaks] + return internal::RegisterBenchmarkInternal( + ::new BenchType(name, std::forward(fn))); +} +#endif + +#if defined(BENCHMARK_HAS_CXX11) && \ + (!defined(BENCHMARK_GCC_VERSION) || BENCHMARK_GCC_VERSION >= 409) +template +internal::Benchmark* RegisterBenchmark(const std::string& name, Lambda&& fn, + Args&&... args) { + return benchmark::RegisterBenchmark( + name, [=](benchmark::State& st) { fn(st, args...); }); +} +#else +#define BENCHMARK_HAS_NO_VARIADIC_REGISTER_BENCHMARK +#endif + +// The base class for all fixture tests. +class Fixture : public internal::Benchmark { + public: + Fixture() : internal::Benchmark("") {} + + void Run(State& st) BENCHMARK_OVERRIDE { + this->SetUp(st); + this->BenchmarkCase(st); + this->TearDown(st); + } + + // These will be deprecated ... + virtual void SetUp(const State&) {} + virtual void TearDown(const State&) {} + // ... In favor of these. + virtual void SetUp(State& st) { SetUp(const_cast(st)); } + virtual void TearDown(State& st) { TearDown(const_cast(st)); } + + protected: + virtual void BenchmarkCase(State&) = 0; +}; +} // namespace benchmark + +// ------------------------------------------------------ +// Macro to register benchmarks + +// Check that __COUNTER__ is defined and that __COUNTER__ increases by 1 +// every time it is expanded. X + 1 == X + 0 is used in case X is defined to be +// empty. If X is empty the expression becomes (+1 == +0). +#if defined(__COUNTER__) && (__COUNTER__ + 1 == __COUNTER__ + 0) +#define BENCHMARK_PRIVATE_UNIQUE_ID __COUNTER__ +#else +#define BENCHMARK_PRIVATE_UNIQUE_ID __LINE__ +#endif + +// Helpers for generating unique variable names +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK_PRIVATE_NAME(...) 
\ + BENCHMARK_PRIVATE_CONCAT(benchmark_uniq_, BENCHMARK_PRIVATE_UNIQUE_ID, \ + __VA_ARGS__) +#else +#define BENCHMARK_PRIVATE_NAME(n) \ + BENCHMARK_PRIVATE_CONCAT(benchmark_uniq_, BENCHMARK_PRIVATE_UNIQUE_ID, n) +#endif // BENCHMARK_HAS_CXX11 + +#define BENCHMARK_PRIVATE_CONCAT(a, b, c) BENCHMARK_PRIVATE_CONCAT2(a, b, c) +#define BENCHMARK_PRIVATE_CONCAT2(a, b, c) a##b##c +// Helper for concatenation with macro name expansion +#define BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method) \ + BaseClass##_##Method##_Benchmark + +#define BENCHMARK_PRIVATE_DECLARE(n) \ + /* NOLINTNEXTLINE(misc-use-anonymous-namespace) */ \ + static ::benchmark::internal::Benchmark* BENCHMARK_PRIVATE_NAME(n) \ + BENCHMARK_UNUSED + +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK(...) \ + BENCHMARK_PRIVATE_DECLARE(_benchmark_) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark(#__VA_ARGS__, \ + __VA_ARGS__))) +#else +#define BENCHMARK(n) \ + BENCHMARK_PRIVATE_DECLARE(n) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark(#n, n))) +#endif // BENCHMARK_HAS_CXX11 + +// Old-style macros +#define BENCHMARK_WITH_ARG(n, a) BENCHMARK(n)->Arg((a)) +#define BENCHMARK_WITH_ARG2(n, a1, a2) BENCHMARK(n)->Args({(a1), (a2)}) +#define BENCHMARK_WITH_UNIT(n, t) BENCHMARK(n)->Unit((t)) +#define BENCHMARK_RANGE(n, lo, hi) BENCHMARK(n)->Range((lo), (hi)) +#define BENCHMARK_RANGE2(n, l1, h1, l2, h2) \ + BENCHMARK(n)->RangePair({{(l1), (h1)}, {(l2), (h2)}}) + +#ifdef BENCHMARK_HAS_CXX11 + +// Register a benchmark which invokes the function specified by `func` +// with the additional arguments specified by `...`. +// +// For example: +// +// template ` +// void BM_takes_args(benchmark::State& state, ExtraArgs&&... extra_args) { +// [...] 
+//} +// /* Registers a benchmark named "BM_takes_args/int_string_test` */ +// BENCHMARK_CAPTURE(BM_takes_args, int_string_test, 42, std::string("abc")); +#define BENCHMARK_CAPTURE(func, test_case_name, ...) \ + BENCHMARK_PRIVATE_DECLARE(_benchmark_) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark( \ + #func "/" #test_case_name, \ + [](::benchmark::State& st) { func(st, __VA_ARGS__); }))) + +#endif // BENCHMARK_HAS_CXX11 + +// This will register a benchmark for a templatized function. For example: +// +// template +// void BM_Foo(int iters); +// +// BENCHMARK_TEMPLATE(BM_Foo, 1); +// +// will register BM_Foo<1> as a benchmark. +#define BENCHMARK_TEMPLATE1(n, a) \ + BENCHMARK_PRIVATE_DECLARE(n) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark(#n "<" #a ">", n))) + +#define BENCHMARK_TEMPLATE2(n, a, b) \ + BENCHMARK_PRIVATE_DECLARE(n) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark(#n "<" #a "," #b ">", \ + n))) + +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK_TEMPLATE(n, ...) \ + BENCHMARK_PRIVATE_DECLARE(n) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark( \ + #n "<" #__VA_ARGS__ ">", n<__VA_ARGS__>))) +#else +#define BENCHMARK_TEMPLATE(n, a) BENCHMARK_TEMPLATE1(n, a) +#endif + +#ifdef BENCHMARK_HAS_CXX11 +// This will register a benchmark for a templatized function, +// with the additional arguments specified by `...`. +// +// For example: +// +// template ` +// void BM_takes_args(benchmark::State& state, ExtraArgs&&... extra_args) { +// [...] +//} +// /* Registers a benchmark named "BM_takes_args/int_string_test` */ +// BENCHMARK_TEMPLATE1_CAPTURE(BM_takes_args, void, int_string_test, 42, +// std::string("abc")); +#define BENCHMARK_TEMPLATE1_CAPTURE(func, a, test_case_name, ...) 
\ + BENCHMARK_CAPTURE(func, test_case_name, __VA_ARGS__) + +#define BENCHMARK_TEMPLATE2_CAPTURE(func, a, b, test_case_name, ...) \ + BENCHMARK_PRIVATE_DECLARE(func) = \ + (::benchmark::internal::RegisterBenchmarkInternal( \ + new ::benchmark::internal::FunctionBenchmark( \ + #func "<" #a "," #b ">" \ + "/" #test_case_name, \ + [](::benchmark::State& st) { func(st, __VA_ARGS__); }))) +#endif // BENCHMARK_HAS_CXX11 + +#define BENCHMARK_PRIVATE_DECLARE_F(BaseClass, Method) \ + class BaseClass##_##Method##_Benchmark : public BaseClass { \ + public: \ + BaseClass##_##Method##_Benchmark() { \ + this->SetName(#BaseClass "/" #Method); \ + } \ + \ + protected: \ + void BenchmarkCase(::benchmark::State&) BENCHMARK_OVERRIDE; \ + }; + +#define BENCHMARK_TEMPLATE1_PRIVATE_DECLARE_F(BaseClass, Method, a) \ + class BaseClass##_##Method##_Benchmark : public BaseClass { \ + public: \ + BaseClass##_##Method##_Benchmark() { \ + this->SetName(#BaseClass "<" #a ">/" #Method); \ + } \ + \ + protected: \ + void BenchmarkCase(::benchmark::State&) BENCHMARK_OVERRIDE; \ + }; + +#define BENCHMARK_TEMPLATE2_PRIVATE_DECLARE_F(BaseClass, Method, a, b) \ + class BaseClass##_##Method##_Benchmark : public BaseClass { \ + public: \ + BaseClass##_##Method##_Benchmark() { \ + this->SetName(#BaseClass "<" #a "," #b ">/" #Method); \ + } \ + \ + protected: \ + void BenchmarkCase(::benchmark::State&) BENCHMARK_OVERRIDE; \ + }; + +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK_TEMPLATE_PRIVATE_DECLARE_F(BaseClass, Method, ...) 
\ + class BaseClass##_##Method##_Benchmark : public BaseClass<__VA_ARGS__> { \ + public: \ + BaseClass##_##Method##_Benchmark() { \ + this->SetName(#BaseClass "<" #__VA_ARGS__ ">/" #Method); \ + } \ + \ + protected: \ + void BenchmarkCase(::benchmark::State&) BENCHMARK_OVERRIDE; \ + }; +#else +#define BENCHMARK_TEMPLATE_PRIVATE_DECLARE_F(n, a) \ + BENCHMARK_TEMPLATE1_PRIVATE_DECLARE_F(n, a) +#endif + +#define BENCHMARK_DEFINE_F(BaseClass, Method) \ + BENCHMARK_PRIVATE_DECLARE_F(BaseClass, Method) \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase + +#define BENCHMARK_TEMPLATE1_DEFINE_F(BaseClass, Method, a) \ + BENCHMARK_TEMPLATE1_PRIVATE_DECLARE_F(BaseClass, Method, a) \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase + +#define BENCHMARK_TEMPLATE2_DEFINE_F(BaseClass, Method, a, b) \ + BENCHMARK_TEMPLATE2_PRIVATE_DECLARE_F(BaseClass, Method, a, b) \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase + +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK_TEMPLATE_DEFINE_F(BaseClass, Method, ...) \ + BENCHMARK_TEMPLATE_PRIVATE_DECLARE_F(BaseClass, Method, __VA_ARGS__) \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase +#else +#define BENCHMARK_TEMPLATE_DEFINE_F(BaseClass, Method, a) \ + BENCHMARK_TEMPLATE1_DEFINE_F(BaseClass, Method, a) +#endif + +#define BENCHMARK_REGISTER_F(BaseClass, Method) \ + BENCHMARK_PRIVATE_REGISTER_F(BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)) + +#define BENCHMARK_PRIVATE_REGISTER_F(TestName) \ + BENCHMARK_PRIVATE_DECLARE(TestName) = \ + (::benchmark::internal::RegisterBenchmarkInternal(new TestName())) + +// This macro will define and register a benchmark within a fixture class. 
+#define BENCHMARK_F(BaseClass, Method) \ + BENCHMARK_PRIVATE_DECLARE_F(BaseClass, Method) \ + BENCHMARK_REGISTER_F(BaseClass, Method); \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase + +#define BENCHMARK_TEMPLATE1_F(BaseClass, Method, a) \ + BENCHMARK_TEMPLATE1_PRIVATE_DECLARE_F(BaseClass, Method, a) \ + BENCHMARK_REGISTER_F(BaseClass, Method); \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase + +#define BENCHMARK_TEMPLATE2_F(BaseClass, Method, a, b) \ + BENCHMARK_TEMPLATE2_PRIVATE_DECLARE_F(BaseClass, Method, a, b) \ + BENCHMARK_REGISTER_F(BaseClass, Method); \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase + +#ifdef BENCHMARK_HAS_CXX11 +#define BENCHMARK_TEMPLATE_F(BaseClass, Method, ...) \ + BENCHMARK_TEMPLATE_PRIVATE_DECLARE_F(BaseClass, Method, __VA_ARGS__) \ + BENCHMARK_REGISTER_F(BaseClass, Method); \ + void BENCHMARK_PRIVATE_CONCAT_NAME(BaseClass, Method)::BenchmarkCase +#else +#define BENCHMARK_TEMPLATE_F(BaseClass, Method, a) \ + BENCHMARK_TEMPLATE1_F(BaseClass, Method, a) +#endif + +// Helper macro to create a main routine in a test that runs the benchmarks +// Note the workaround for Hexagon simulator passing argc != 0, argv = NULL. 
+#define BENCHMARK_MAIN() \ + int main(int argc, char** argv) { \ + char arg0_default[] = "benchmark"; \ + char* args_default = arg0_default; \ + if (!argv) { \ + argc = 1; \ + argv = &args_default; \ + } \ + ::benchmark::Initialize(&argc, argv); \ + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; \ + ::benchmark::RunSpecifiedBenchmarks(); \ + ::benchmark::Shutdown(); \ + return 0; \ + } \ + int main(int, char**) + +// ------------------------------------------------------ +// Benchmark Reporters + +namespace benchmark { + +struct BENCHMARK_EXPORT CPUInfo { + struct CacheInfo { + std::string type; + int level; + int size; + int num_sharing; + }; + + enum Scaling { UNKNOWN, ENABLED, DISABLED }; + + int num_cpus; + Scaling scaling; + double cycles_per_second; + std::vector caches; + std::vector load_avg; + + static const CPUInfo& Get(); + + private: + CPUInfo(); + BENCHMARK_DISALLOW_COPY_AND_ASSIGN(CPUInfo); +}; + +// Adding Struct for System Information +struct BENCHMARK_EXPORT SystemInfo { + std::string name; + static const SystemInfo& Get(); + + private: + SystemInfo(); + BENCHMARK_DISALLOW_COPY_AND_ASSIGN(SystemInfo); +}; + +// BenchmarkName contains the components of the Benchmark's name +// which allows individual fields to be modified or cleared before +// building the final name using 'str()'. +struct BENCHMARK_EXPORT BenchmarkName { + std::string function_name; + std::string args; + std::string min_time; + std::string min_warmup_time; + std::string iterations; + std::string repetitions; + std::string time_type; + std::string threads; + + // Return the full name of the benchmark with each non-empty + // field separated by a '/' + std::string str() const; +}; + +// Interface for custom benchmark result printers. +// By default, benchmark reports are printed to stdout. However an application +// can control the destination of the reports by calling +// RunSpecifiedBenchmarks and passing it a custom reporter object. 
+// The reporter object must implement the following interface. +class BENCHMARK_EXPORT BenchmarkReporter { + public: + struct Context { + CPUInfo const& cpu_info; + SystemInfo const& sys_info; + // The number of chars in the longest benchmark name. + size_t name_field_width; + static const char* executable_name; + Context(); + }; + + struct BENCHMARK_EXPORT Run { + static const int64_t no_repetition_index = -1; + enum RunType { RT_Iteration, RT_Aggregate }; + + Run() + : run_type(RT_Iteration), + aggregate_unit(kTime), + skipped(internal::NotSkipped), + iterations(1), + threads(1), + time_unit(GetDefaultTimeUnit()), + real_accumulated_time(0), + cpu_accumulated_time(0), + max_heapbytes_used(0), + use_real_time_for_initial_big_o(false), + complexity(oNone), + complexity_lambda(), + complexity_n(0), + report_big_o(false), + report_rms(false), + memory_result(NULL), + allocs_per_iter(0.0) {} + + std::string benchmark_name() const; + BenchmarkName run_name; + int64_t family_index; + int64_t per_family_instance_index; + RunType run_type; + std::string aggregate_name; + StatisticUnit aggregate_unit; + std::string report_label; // Empty if not set by benchmark. + internal::Skipped skipped; + std::string skip_message; + + IterationCount iterations; + int64_t threads; + int64_t repetition_index; + int64_t repetitions; + TimeUnit time_unit; + double real_accumulated_time; + double cpu_accumulated_time; + + // Return a value representing the real time per iteration in the unit + // specified by 'time_unit'. + // NOTE: If 'iterations' is zero the returned value represents the + // accumulated time. + double GetAdjustedRealTime() const; + + // Return a value representing the cpu time per iteration in the unit + // specified by 'time_unit'. + // NOTE: If 'iterations' is zero the returned value represents the + // accumulated time. + double GetAdjustedCPUTime() const; + + // This is set to 0.0 if memory tracing is not enabled. 
+ double max_heapbytes_used; + + // By default Big-O is computed for CPU time, but that is not what you want + // to happen when manual time was requested, which is stored as real time. + bool use_real_time_for_initial_big_o; + + // Keep track of arguments to compute asymptotic complexity + BigO complexity; + BigOFunc* complexity_lambda; + ComplexityN complexity_n; + + // what statistics to compute from the measurements + const std::vector* statistics; + + // Inform print function whether the current run is a complexity report + bool report_big_o; + bool report_rms; + + UserCounters counters; + + // Memory metrics. + const MemoryManager::Result* memory_result; + double allocs_per_iter; + }; + + struct PerFamilyRunReports { + PerFamilyRunReports() : num_runs_total(0), num_runs_done(0) {} + + // How many runs will all instances of this benchmark perform? + int num_runs_total; + + // How many runs have happened already? + int num_runs_done; + + // The reports about (non-errneous!) runs of this family. + std::vector Runs; + }; + + // Construct a BenchmarkReporter with the output stream set to 'std::cout' + // and the error stream set to 'std::cerr' + BenchmarkReporter(); + + // Called once for every suite of benchmarks run. + // The parameter "context" contains information that the + // reporter may wish to use when generating its report, for example the + // platform under which the benchmarks are running. The benchmark run is + // never started if this function returns false, allowing the reporter + // to skip runs based on the context information. + virtual bool ReportContext(const Context& context) = 0; + + // Called once for each group of benchmark runs, gives information about + // the configurations of the runs. + virtual void ReportRunsConfig(double /*min_time*/, + bool /*has_explicit_iters*/, + IterationCount /*iters*/) {} + + // Called once for each group of benchmark runs, gives information about + // cpu-time and heap memory usage during the benchmark run. 
If the group + // of runs contained more than two entries then 'report' contains additional + // elements representing the mean and standard deviation of those runs. + // Additionally if this group of runs was the last in a family of benchmarks + // 'reports' contains additional entries representing the asymptotic + // complexity and RMS of that benchmark family. + virtual void ReportRuns(const std::vector& report) = 0; + + // Called once and only once after ever group of benchmarks is run and + // reported. + virtual void Finalize() {} + + // REQUIRES: The object referenced by 'out' is valid for the lifetime + // of the reporter. + void SetOutputStream(std::ostream* out) { + assert(out); + output_stream_ = out; + } + + // REQUIRES: The object referenced by 'err' is valid for the lifetime + // of the reporter. + void SetErrorStream(std::ostream* err) { + assert(err); + error_stream_ = err; + } + + std::ostream& GetOutputStream() const { return *output_stream_; } + + std::ostream& GetErrorStream() const { return *error_stream_; } + + virtual ~BenchmarkReporter(); + + // Write a human readable string to 'out' representing the specified + // 'context'. + // REQUIRES: 'out' is non-null. + static void PrintBasicContext(std::ostream* out, Context const& context); + + private: + std::ostream* output_stream_; + std::ostream* error_stream_; +}; + +// Simple reporter that outputs benchmark data to the console. This is the +// default reporter used by RunSpecifiedBenchmarks(). 
+class BENCHMARK_EXPORT ConsoleReporter : public BenchmarkReporter { + public: + enum OutputOptions { + OO_None = 0, + OO_Color = 1, + OO_Tabular = 2, + OO_ColorTabular = OO_Color | OO_Tabular, + OO_Defaults = OO_ColorTabular + }; + explicit ConsoleReporter(OutputOptions opts_ = OO_Defaults) + : output_options_(opts_), name_field_width_(0), printed_header_(false) {} + + bool ReportContext(const Context& context) BENCHMARK_OVERRIDE; + void ReportRuns(const std::vector& reports) BENCHMARK_OVERRIDE; + + protected: + virtual void PrintRunData(const Run& report); + virtual void PrintHeader(const Run& report); + + OutputOptions output_options_; + size_t name_field_width_; + UserCounters prev_counters_; + bool printed_header_; +}; + +class BENCHMARK_EXPORT JSONReporter : public BenchmarkReporter { + public: + JSONReporter() : first_report_(true) {} + bool ReportContext(const Context& context) BENCHMARK_OVERRIDE; + void ReportRuns(const std::vector& reports) BENCHMARK_OVERRIDE; + void Finalize() BENCHMARK_OVERRIDE; + + private: + void PrintRunData(const Run& report); + + bool first_report_; +}; + +class BENCHMARK_EXPORT BENCHMARK_DEPRECATED_MSG( + "The CSV Reporter will be removed in a future release") CSVReporter + : public BenchmarkReporter { + public: + CSVReporter() : printed_header_(false) {} + bool ReportContext(const Context& context) BENCHMARK_OVERRIDE; + void ReportRuns(const std::vector& reports) BENCHMARK_OVERRIDE; + + private: + void PrintRunData(const Run& report); + + bool printed_header_; + std::set user_counter_names_; +}; + +inline const char* GetTimeUnitString(TimeUnit unit) { + switch (unit) { + case kSecond: + return "s"; + case kMillisecond: + return "ms"; + case kMicrosecond: + return "us"; + case kNanosecond: + return "ns"; + } + BENCHMARK_UNREACHABLE(); +} + +inline double GetTimeUnitMultiplier(TimeUnit unit) { + switch (unit) { + case kSecond: + return 1; + case kMillisecond: + return 1e3; + case kMicrosecond: + return 1e6; + case kNanosecond: + 
return 1e9; + } + BENCHMARK_UNREACHABLE(); +} + +// Creates a list of integer values for the given range and multiplier. +// This can be used together with ArgsProduct() to allow multiple ranges +// with different multipliers. +// Example: +// ArgsProduct({ +// CreateRange(0, 1024, /*multi=*/32), +// CreateRange(0, 100, /*multi=*/4), +// CreateDenseRange(0, 4, /*step=*/1), +// }); +BENCHMARK_EXPORT +std::vector CreateRange(int64_t lo, int64_t hi, int multi); + +// Creates a list of integer values for the given range and step. +BENCHMARK_EXPORT +std::vector CreateDenseRange(int64_t start, int64_t limit, int step); + +} // namespace benchmark + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#endif // BENCHMARK_BENCHMARK_H_ diff --git a/third_party/benchmark/include/benchmark/export.h b/third_party/benchmark/include/benchmark/export.h new file mode 100644 index 0000000..f96f859 --- /dev/null +++ b/third_party/benchmark/include/benchmark/export.h @@ -0,0 +1,47 @@ +#ifndef BENCHMARK_EXPORT_H +#define BENCHMARK_EXPORT_H + +#if defined(_WIN32) +#define EXPORT_ATTR __declspec(dllexport) +#define IMPORT_ATTR __declspec(dllimport) +#define NO_EXPORT_ATTR +#define DEPRECATED_ATTR __declspec(deprecated) +#else // _WIN32 +#define EXPORT_ATTR __attribute__((visibility("default"))) +#define IMPORT_ATTR __attribute__((visibility("default"))) +#define NO_EXPORT_ATTR __attribute__((visibility("hidden"))) +#define DEPRECATE_ATTR __attribute__((__deprecated__)) +#endif // _WIN32 + +#ifdef BENCHMARK_STATIC_DEFINE +#define BENCHMARK_EXPORT +#define BENCHMARK_NO_EXPORT +#else // BENCHMARK_STATIC_DEFINE +#ifndef BENCHMARK_EXPORT +#ifdef benchmark_EXPORTS +/* We are building this library */ +#define BENCHMARK_EXPORT EXPORT_ATTR +#else // benchmark_EXPORTS +/* We are using this library */ +#define BENCHMARK_EXPORT IMPORT_ATTR +#endif // benchmark_EXPORTS +#endif // !BENCHMARK_EXPORT + +#ifndef BENCHMARK_NO_EXPORT +#define BENCHMARK_NO_EXPORT NO_EXPORT_ATTR +#endif // 
!BENCHMARK_NO_EXPORT +#endif // BENCHMARK_STATIC_DEFINE + +#ifndef BENCHMARK_DEPRECATED +#define BENCHMARK_DEPRECATED DEPRECATE_ATTR +#endif // BENCHMARK_DEPRECATED + +#ifndef BENCHMARK_DEPRECATED_EXPORT +#define BENCHMARK_DEPRECATED_EXPORT BENCHMARK_EXPORT BENCHMARK_DEPRECATED +#endif // BENCHMARK_DEPRECATED_EXPORT + +#ifndef BENCHMARK_DEPRECATED_NO_EXPORT +#define BENCHMARK_DEPRECATED_NO_EXPORT BENCHMARK_NO_EXPORT BENCHMARK_DEPRECATED +#endif // BENCHMARK_DEPRECATED_EXPORT + +#endif /* BENCHMARK_EXPORT_H */ diff --git a/third_party/benchmark/pyproject.toml b/third_party/benchmark/pyproject.toml new file mode 100644 index 0000000..14f173f --- /dev/null +++ b/third_party/benchmark/pyproject.toml @@ -0,0 +1,77 @@ +[build-system] +requires = ["setuptools<73"] +build-backend = "setuptools.build_meta" + +[project] +name = "google_benchmark" +description = "A library to benchmark code snippets." +requires-python = ">=3.10" +license = { file = "LICENSE" } +keywords = ["benchmark"] + +authors = [{ name = "Google", email = "benchmark-discuss@googlegroups.com" }] + +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Testing", + "Topic :: System :: Benchmark", +] + +dynamic = ["readme", "version"] + +dependencies = ["absl-py>=0.7.1"] + +[project.optional-dependencies] +dev = ["pre-commit>=3.3.3"] + +[project.urls] +Homepage = "https://github.com/google/benchmark" +Documentation = "https://github.com/google/benchmark/tree/main/docs" +Repository = "https://github.com/google/benchmark.git" +Discord = "https://discord.gg/cz7UX7wKC2" + +[tool.setuptools] +package-dir = { "" = "bindings/python" } +zip-safe = false + +[tool.setuptools.packages.find] +where = ["bindings/python"] + 
+[tool.setuptools.dynamic] +readme = { file = "README.md", content-type = "text/markdown" } +version = { attr = "google_benchmark.__version__" } + +[tool.mypy] +check_untyped_defs = true +disallow_incomplete_defs = true +pretty = true +python_version = "3.11" +strict_optional = false +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["yaml"] +ignore_missing_imports = true + +[tool.ruff] +# explicitly tell ruff the source directory to correctly identify first-party package. +src = ["bindings/python"] + +line-length = 80 +target-version = "py311" + +[tool.ruff.lint] +# Enable pycodestyle (`E`, `W`), Pyflakes (`F`), and isort (`I`) codes by default. +select = ["E", "F", "I", "W"] +ignore = [ + "E501", # line too long +] + +[tool.ruff.lint.isort] +combine-as-imports = true diff --git a/third_party/benchmark/setup.py b/third_party/benchmark/setup.py new file mode 100644 index 0000000..69cc49d --- /dev/null +++ b/third_party/benchmark/setup.py @@ -0,0 +1,169 @@ +import contextlib +import os +import platform +import re +import shutil +import sys +from pathlib import Path +from typing import Any, Generator + +import setuptools +from setuptools.command import build_ext + +IS_WINDOWS = platform.system() == "Windows" +IS_MAC = platform.system() == "Darwin" +IS_LINUX = platform.system() == "Linux" + +# hardcoded SABI-related options. Requires that each Python interpreter +# (hermetic or not) participating is of the same major-minor version. +py_limited_api = sys.version_info >= (3, 12) +options = {"bdist_wheel": {"py_limited_api": "cp312"}} if py_limited_api else {} + + +def is_cibuildwheel() -> bool: + return os.getenv("CIBUILDWHEEL") is not None + + +@contextlib.contextmanager +def _maybe_patch_toolchains() -> Generator[None, None, None]: + """ + Patch rules_python toolchains to ignore root user error + when run in a Docker container on Linux in cibuildwheel. 
+ """ + + def fmt_toolchain_args(matchobj): + suffix = "ignore_root_user_error = True" + callargs = matchobj.group(1) + # toolchain def is broken over multiple lines + if callargs.endswith("\n"): + callargs = callargs + " " + suffix + ",\n" + # toolchain def is on one line. + else: + callargs = callargs + ", " + suffix + return "python.toolchain(" + callargs + ")" + + CIBW_LINUX = is_cibuildwheel() and IS_LINUX + module_bazel = Path("MODULE.bazel") + content: str = module_bazel.read_text() + try: + if CIBW_LINUX: + module_bazel.write_text( + re.sub( + r"python.toolchain\(([\w\"\s,.=]*)\)", + fmt_toolchain_args, + content, + ) + ) + yield + finally: + if CIBW_LINUX: + module_bazel.write_text(content) + + +class BazelExtension(setuptools.Extension): + """A C/C++ extension that is defined as a Bazel BUILD target.""" + + def __init__(self, name: str, bazel_target: str, **kwargs: Any): + super().__init__(name=name, sources=[], **kwargs) + + self.bazel_target = bazel_target + stripped_target = bazel_target.split("//")[-1] + self.relpath, self.target_name = stripped_target.split(":") + + +class BuildBazelExtension(build_ext.build_ext): + """A command that runs Bazel to build a C/C++ extension.""" + + def run(self): + for ext in self.extensions: + self.bazel_build(ext) + super().run() + # explicitly call `bazel shutdown` for graceful exit + self.spawn(["bazel", "shutdown"]) + + def copy_extensions_to_source(self): + """ + Copy generated extensions into the source tree. + This is done in the ``bazel_build`` method, so it's not necessary to + do again in the `build_ext` base class. + """ + pass + + def bazel_build(self, ext: BazelExtension) -> None: + """Runs the bazel build to create the package.""" + temp_path = Path(self.build_temp) + + # We round to the minor version, which makes rules_python + # look up the latest available patch version internally. 
+ python_version = "{0}.{1}".format(*sys.version_info[:2]) + + bazel_argv = [ + "bazel", + "run", + ext.bazel_target, + f"--symlink_prefix={temp_path / 'bazel-'}", + f"--compilation_mode={'dbg' if self.debug else 'opt'}", + # C++17 is required by nanobind + f"--cxxopt={'/std:c++17' if IS_WINDOWS else '-std=c++17'}", + f"--@rules_python//python/config_settings:python_version={python_version}", + ] + + if ext.py_limited_api: + bazel_argv += ["--@nanobind_bazel//:py-limited-api=cp312"] + + if IS_WINDOWS: + # Link with python*.lib. + for library_dir in self.library_dirs: + bazel_argv.append("--linkopt=/LIBPATH:" + library_dir) + elif IS_MAC: + # C++17 needs macOS 10.14 at minimum + bazel_argv.append("--macos_minimum_os=10.14") + + with _maybe_patch_toolchains(): + self.spawn(bazel_argv) + + if IS_WINDOWS: + suffix = ".pyd" + else: + suffix = ".abi3.so" if ext.py_limited_api else ".so" + + # copy the Bazel build artifacts into setuptools' libdir, + # from where the wheel is built. + pkgname = "google_benchmark" + pythonroot = Path("bindings") / "python" / "google_benchmark" + srcdir = temp_path / "bazel-bin" / pythonroot + libdir = Path(self.build_lib) / pkgname + for root, dirs, files in os.walk(srcdir, topdown=True): + # exclude runfiles directories and children. + dirs[:] = [d for d in dirs if "runfiles" not in d] + + for f in files: + fp = Path(f) + should_copy = False + # we do not want the bare .so file included + # when building for ABI3, so we require a + # full and exact match on the file extension. + if "".join(fp.suffixes) == suffix: + should_copy = True + elif fp.suffix == ".pyi": + should_copy = True + elif Path(root) == srcdir and f == "py.typed": + # copy py.typed, but only at the package root. 
+ should_copy = True + + if should_copy: + shutil.copyfile(root / fp, libdir / fp) + + +setuptools.setup( + cmdclass=dict(build_ext=BuildBazelExtension), + package_data={"google_benchmark": ["py.typed", "*.pyi"]}, + ext_modules=[ + BazelExtension( + name="google_benchmark._benchmark", + bazel_target="//bindings/python/google_benchmark:benchmark_stubgen", + py_limited_api=py_limited_api, + ) + ], + options=options, +) diff --git a/third_party/benchmark/src/CMakeLists.txt b/third_party/benchmark/src/CMakeLists.txt new file mode 100644 index 0000000..32126c0 --- /dev/null +++ b/third_party/benchmark/src/CMakeLists.txt @@ -0,0 +1,180 @@ +#Allow the source files to find headers in src / +include(GNUInstallDirs) +include_directories(${PROJECT_SOURCE_DIR}/src) + +if (DEFINED BENCHMARK_CXX_LINKER_FLAGS) + list(APPEND CMAKE_SHARED_LINKER_FLAGS ${BENCHMARK_CXX_LINKER_FLAGS}) + list(APPEND CMAKE_MODULE_LINKER_FLAGS ${BENCHMARK_CXX_LINKER_FLAGS}) +endif() + +file(GLOB + SOURCE_FILES + *.cc + ${PROJECT_SOURCE_DIR}/include/benchmark/*.h + ${CMAKE_CURRENT_SOURCE_DIR}/*.h) +file(GLOB BENCHMARK_MAIN "benchmark_main.cc") +foreach(item ${BENCHMARK_MAIN}) + list(REMOVE_ITEM SOURCE_FILES "${item}") +endforeach() + +add_library(benchmark ${SOURCE_FILES}) +add_library(benchmark::benchmark ALIAS benchmark) +set_target_properties(benchmark PROPERTIES + OUTPUT_NAME "benchmark" + VERSION ${GENERIC_LIB_VERSION} + SOVERSION ${GENERIC_LIB_SOVERSION} +) +target_include_directories(benchmark PUBLIC + $ +) + +set_property( + SOURCE benchmark.cc + APPEND + PROPERTY COMPILE_DEFINITIONS + BENCHMARK_VERSION="${VERSION}" +) + +# libpfm, if available +if (PFM_FOUND) + target_link_libraries(benchmark PRIVATE PFM::libpfm) + target_compile_definitions(benchmark PRIVATE -DHAVE_LIBPFM) +endif() + +# pthread affinity, if available +if(HAVE_PTHREAD_AFFINITY) + target_compile_definitions(benchmark PRIVATE -DBENCHMARK_HAS_PTHREAD_AFFINITY) +endif() + +# Link threads. 
+target_link_libraries(benchmark PRIVATE Threads::Threads) + +target_link_libraries(benchmark PRIVATE ${BENCHMARK_CXX_LIBRARIES}) + +if(HAVE_LIB_RT) + target_link_libraries(benchmark PRIVATE rt) +endif(HAVE_LIB_RT) + + +# We need extra libraries on Windows +if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") + target_link_libraries(benchmark PRIVATE shlwapi) +endif() + +# We need extra libraries on Solaris +if(${CMAKE_SYSTEM_NAME} MATCHES "SunOS") + target_link_libraries(benchmark PRIVATE kstat) + set(BENCHMARK_PRIVATE_LINK_LIBRARIES -lkstat) +endif() + +if (NOT BUILD_SHARED_LIBS) + target_compile_definitions(benchmark PUBLIC -DBENCHMARK_STATIC_DEFINE) +endif() + +# Benchmark main library +add_library(benchmark_main "benchmark_main.cc") +add_library(benchmark::benchmark_main ALIAS benchmark_main) +set_target_properties(benchmark_main PROPERTIES + OUTPUT_NAME "benchmark_main" + VERSION ${GENERIC_LIB_VERSION} + SOVERSION ${GENERIC_LIB_SOVERSION} + DEFINE_SYMBOL benchmark_EXPORTS +) +target_link_libraries(benchmark_main PUBLIC benchmark::benchmark) + +set(generated_dir "${PROJECT_BINARY_DIR}") + +set(version_config "${generated_dir}/${PROJECT_NAME}ConfigVersion.cmake") +set(project_config "${generated_dir}/${PROJECT_NAME}Config.cmake") +set(pkg_config "${generated_dir}/${PROJECT_NAME}.pc") +set(pkg_config_main "${generated_dir}/${PROJECT_NAME}_main.pc") +set(targets_to_export benchmark benchmark_main) +set(targets_export_name "${PROJECT_NAME}Targets") + +set(namespace "${PROJECT_NAME}::") + +include(CMakePackageConfigHelpers) + +configure_package_config_file ( + ${PROJECT_SOURCE_DIR}/cmake/Config.cmake.in + ${project_config} + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME} + NO_SET_AND_CHECK_MACRO + NO_CHECK_REQUIRED_COMPONENTS_MACRO +) +write_basic_package_version_file( + "${version_config}" VERSION ${GENERIC_LIB_VERSION} COMPATIBILITY SameMajorVersion +) + +configure_file("${PROJECT_SOURCE_DIR}/cmake/benchmark.pc.in" "${pkg_config}" @ONLY) 
+configure_file("${PROJECT_SOURCE_DIR}/cmake/benchmark_main.pc.in" "${pkg_config_main}" @ONLY) + +export ( + TARGETS ${targets_to_export} + NAMESPACE "${namespace}" + FILE ${generated_dir}/${targets_export_name}.cmake +) + +if (BENCHMARK_ENABLE_INSTALL) + # Install target (will install the library to specified CMAKE_INSTALL_PREFIX variable) + install( + TARGETS ${targets_to_export} + EXPORT ${targets_export_name} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + + install( + DIRECTORY "${PROJECT_SOURCE_DIR}/include/benchmark" + "${PROJECT_BINARY_DIR}/include/benchmark" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} + FILES_MATCHING PATTERN "*.*h") + + install( + FILES "${project_config}" "${version_config}" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}") + + install( + FILES "${pkg_config}" "${pkg_config_main}" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + + install( + EXPORT "${targets_export_name}" + NAMESPACE "${namespace}" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}") +endif() + +if (BENCHMARK_ENABLE_DOXYGEN) + find_package(Doxygen REQUIRED) + set(DOXYGEN_QUIET YES) + set(DOXYGEN_RECURSIVE YES) + set(DOXYGEN_GENERATE_HTML YES) + set(DOXYGEN_GENERATE_MAN NO) + set(DOXYGEN_MARKDOWN_SUPPORT YES) + set(DOXYGEN_BUILTIN_STL_SUPPORT YES) + set(DOXYGEN_EXTRACT_PACKAGE YES) + set(DOXYGEN_EXTRACT_STATIC YES) + set(DOXYGEN_SHOW_INCLUDE_FILES YES) + set(DOXYGEN_BINARY_TOC YES) + set(DOXYGEN_TOC_EXPAND YES) + set(DOXYGEN_USE_MDFILE_AS_MAINPAGE "index.md") + doxygen_add_docs(benchmark_doxygen + docs + include + src + ALL + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMENT "Building documentation with Doxygen.") + if (BENCHMARK_ENABLE_INSTALL AND BENCHMARK_INSTALL_DOCS) + install( + DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/html/" + DESTINATION ${CMAKE_INSTALL_DOCDIR}) + endif() +else() + if 
(BENCHMARK_ENABLE_INSTALL AND BENCHMARK_INSTALL_DOCS)
+    install(
+      DIRECTORY "${PROJECT_SOURCE_DIR}/docs/"
+      DESTINATION ${CMAKE_INSTALL_DOCDIR})
+  endif()
+endif()
diff --git a/third_party/benchmark/src/arraysize.h b/third_party/benchmark/src/arraysize.h
new file mode 100644
index 0000000..51a50f2
--- /dev/null
+++ b/third_party/benchmark/src/arraysize.h
@@ -0,0 +1,33 @@
+#ifndef BENCHMARK_ARRAYSIZE_H_
+#define BENCHMARK_ARRAYSIZE_H_
+
+#include "internal_macros.h"
+
+namespace benchmark {
+namespace internal {
+// The arraysize(arr) macro returns the # of elements in an array arr.
+// The expression is a compile-time constant, and therefore can be
+// used in defining new arrays, for example. If you use arraysize on
+// a pointer by mistake, you will get a compile-time error.
+//
+
+// This template function declaration is used in defining arraysize.
+// Note that the function doesn't need an implementation, as we only
+// use its type.
+template <typename T, size_t N>
+char (&ArraySizeHelper(T (&array)[N]))[N];
+
+// That gcc wants both of these prototypes seems mysterious. VC, for
+// its part, can't decide which to use (another mystery). Matching of
+// template overloads: the final frontier.
+#ifndef COMPILER_MSVC
+template <typename T, size_t N>
+char (&ArraySizeHelper(const T (&array)[N]))[N];
+#endif
+
+#define arraysize(array) (sizeof(::benchmark::internal::ArraySizeHelper(array)))
+
+}  // end namespace internal
+}  // end namespace benchmark
+
+#endif  // BENCHMARK_ARRAYSIZE_H_
diff --git a/third_party/benchmark/src/benchmark.cc b/third_party/benchmark/src/benchmark.cc
new file mode 100644
index 0000000..0ea90ae
--- /dev/null
+++ b/third_party/benchmark/src/benchmark.cc
@@ -0,0 +1,832 @@
+// Copyright 2015 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "benchmark_api_internal.h"
+#include "benchmark_runner.h"
+#include "internal_macros.h"
+
+#ifndef BENCHMARK_OS_WINDOWS
+#if !defined(BENCHMARK_OS_FUCHSIA) && !defined(BENCHMARK_OS_QURT)
+#include <sys/resource.h>
+#endif
+#include <sys/time.h>
+#include <unistd.h>
+#endif
+
+#include <algorithm>
+#include <atomic>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <limits>
+#include <map>
+#include <memory>
+#include <random>
+#include <string>
+#include <thread>
+#include <utility>
+
+#include "check.h"
+#include "colorprint.h"
+#include "commandlineflags.h"
+#include "complexity.h"
+#include "counter.h"
+#include "internal_macros.h"
+#include "log.h"
+#include "mutex.h"
+#include "perf_counters.h"
+#include "re.h"
+#include "statistics.h"
+#include "string_util.h"
+#include "thread_manager.h"
+#include "thread_timer.h"
+
+namespace benchmark {
+// Print a list of benchmarks. This option overrides all other options.
+BM_DEFINE_bool(benchmark_list_tests, false);
+
+// A regular expression that specifies the set of benchmarks to execute. If
+// this flag is empty, or if this flag is the string \"all\", all benchmarks
+// linked into the binary are run.
+BM_DEFINE_string(benchmark_filter, "");
+
+// Specification of how long to run the benchmark.
+//
+// It can be either an exact number of iterations (specified as `<integer>x`),
+// or a minimum number of seconds (specified as `<float>s`). If the latter
+// format (ie., min seconds) is used, the system may run the benchmark longer
+// until the results are considered significant.
+// +// For backward compatibility, the `s` suffix may be omitted, in which case, +// the specified number is interpreted as the number of seconds. +// +// For cpu-time based tests, this is the lower bound +// on the total cpu time used by all threads that make up the test. For +// real-time based tests, this is the lower bound on the elapsed time of the +// benchmark execution, regardless of number of threads. +BM_DEFINE_string(benchmark_min_time, kDefaultMinTimeStr); + +// Minimum number of seconds a benchmark should be run before results should be +// taken into account. This e.g can be necessary for benchmarks of code which +// needs to fill some form of cache before performance is of interest. +// Note: results gathered within this period are discarded and not used for +// reported result. +BM_DEFINE_double(benchmark_min_warmup_time, 0.0); + +// The number of runs of each benchmark. If greater than 1, the mean and +// standard deviation of the runs will be reported. +BM_DEFINE_int32(benchmark_repetitions, 1); + +// If enabled, forces each benchmark to execute exactly one iteration and one +// repetition, bypassing any configured +// MinTime()/MinWarmUpTime()/Iterations()/Repetitions() +BM_DEFINE_bool(benchmark_dry_run, false); + +// If set, enable random interleaving of repetitions of all benchmarks. +// See http://github.com/google/benchmark/issues/1051 for details. +BM_DEFINE_bool(benchmark_enable_random_interleaving, false); + +// Report the result of each benchmark repetitions. When 'true' is specified +// only the mean, standard deviation, and other statistics are reported for +// repeated benchmarks. Affects all reporters. +BM_DEFINE_bool(benchmark_report_aggregates_only, false); + +// Display the result of each benchmark repetitions. When 'true' is specified +// only the mean, standard deviation, and other statistics are displayed for +// repeated benchmarks. 
Unlike benchmark_report_aggregates_only, only affects
+// the display reporter, but *NOT* file reporter, which will still contain
+// all the output.
+BM_DEFINE_bool(benchmark_display_aggregates_only, false);
+
+// The format to use for console output.
+// Valid values are 'console', 'json', or 'csv'.
+BM_DEFINE_string(benchmark_format, "console");
+
+// The format to use for file output.
+// Valid values are 'console', 'json', or 'csv'.
+BM_DEFINE_string(benchmark_out_format, "json");
+
+// The file to write additional output to.
+BM_DEFINE_string(benchmark_out, "");
+
+// Whether to use colors in the output. Valid values:
+// 'true'/'yes'/1, 'false'/'no'/0, and 'auto'. 'auto' means to use colors if
+// the output is being sent to a terminal and the TERM environment variable is
+// set to a terminal type that supports colors.
+BM_DEFINE_string(benchmark_color, "auto");
+
+// Whether to use tabular format when printing user counters to the console.
+// Valid values: 'true'/'yes'/1, 'false'/'no'/0. Defaults to false.
+BM_DEFINE_bool(benchmark_counters_tabular, false);
+
+// List of additional perf counters to collect, in libpfm format. For more
+// information about libpfm: https://man7.org/linux/man-pages/man3/libpfm.3.html
+BM_DEFINE_string(benchmark_perf_counters, "");
+
+// Extra context to include in the output formatted as comma-separated key-value
+// pairs. Kept internal as it's only used for parsing from env/command line.
+BM_DEFINE_kvpairs(benchmark_context, {});
+
+// Set the default time unit to use for reports
+// Valid values are 'ns', 'us', 'ms' or 's'
+BM_DEFINE_string(benchmark_time_unit, "");
+
+// The level of verbose logging to output
+BM_DEFINE_int32(v, 0);
+
+namespace internal {
+
+std::map<std::string, std::string>* global_context = nullptr;
+
+BENCHMARK_EXPORT std::map<std::string, std::string>*& GetGlobalContext() {
+  return global_context;
+}
+
+static void const volatile* volatile global_force_escape_pointer;
+
+// FIXME: Verify if LTO still messes this up?
+void UseCharPointer(char const volatile* const v) {
+  // We want to escape the pointer `v` so that the compiler can not eliminate
+  // computations that produced it. To do that, we escape the pointer by storing
+  // it into a volatile variable, since generally, volatile store, is not
+  // something the compiler is allowed to elide.
+  global_force_escape_pointer = reinterpret_cast<void const volatile*>(v);
+}
+
+}  // namespace internal
+
+State::State(std::string name, IterationCount max_iters,
+             const std::vector<int64_t>& ranges, int thread_i, int n_threads,
+             internal::ThreadTimer* timer, internal::ThreadManager* manager,
+             internal::PerfCountersMeasurement* perf_counters_measurement,
+             ProfilerManager* profiler_manager)
+    : total_iterations_(0),
+      batch_leftover_(0),
+      max_iterations(max_iters),
+      started_(false),
+      finished_(false),
+      skipped_(internal::NotSkipped),
+      range_(ranges),
+      complexity_n_(0),
+      name_(std::move(name)),
+      thread_index_(thread_i),
+      threads_(n_threads),
+      timer_(timer),
+      manager_(manager),
+      perf_counters_measurement_(perf_counters_measurement),
+      profiler_manager_(profiler_manager) {
+  BM_CHECK(max_iterations != 0) << "At least one iteration must be run";
+  BM_CHECK_LT(thread_index_, threads_)
+      << "thread_index must be less than threads";
+
+  // Add counters with correct flag now. If added with `counters[name]` in
+  // `PauseTiming`, a new `Counter` will be inserted the first time, which
+  // won't have the flag. Inserting them now also reduces the allocations
+  // during the benchmark.
+  if (perf_counters_measurement_) {
+    for (const std::string& counter_name :
+         perf_counters_measurement_->names()) {
+      counters[counter_name] = Counter(0.0, Counter::kAvgIterations);
+    }
+  }
+
+  // Note: The use of offsetof below is technically undefined until C++17
+  // because State is not a standard layout type.
However, all compilers
+  // currently provide well-defined behavior as an extension (which is
+  // demonstrated since constexpr evaluation must diagnose all undefined
+  // behavior). However, GCC and Clang also warn about this use of offsetof,
+  // which must be suppressed.
+#if defined(__INTEL_COMPILER)
+#pragma warning push
+#pragma warning(disable : 1875)
+#elif defined(__GNUC__) || defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Winvalid-offsetof"
+#endif
+#if defined(__NVCC__)
+#pragma nv_diagnostic push
+#pragma nv_diag_suppress 1427
+#endif
+#if defined(__NVCOMPILER)
+#pragma diagnostic push
+#pragma diag_suppress offset_in_non_POD_nonstandard
+#endif
+  // Offset tests to ensure commonly accessed data is on the first cache line.
+  const int cache_line_size = 64;
+  static_assert(
+      offsetof(State, skipped_) <= (cache_line_size - sizeof(skipped_)), "");
+#if defined(__INTEL_COMPILER)
+#pragma warning pop
+#elif defined(__GNUC__) || defined(__clang__)
+#pragma GCC diagnostic pop
+#endif
+#if defined(__NVCC__)
+#pragma nv_diagnostic pop
+#endif
+#if defined(__NVCOMPILER)
+#pragma diagnostic pop
+#endif
+}
+
+void State::PauseTiming() {
+  // Add in time accumulated so far
+  BM_CHECK(started_ && !finished_ && !skipped());
+  timer_->StopTimer();
+  if (perf_counters_measurement_) {
+    std::vector<std::pair<std::string, double>> measurements;
+    if (!perf_counters_measurement_->Stop(measurements)) {
+      BM_CHECK(false) << "Perf counters read the value failed.";
+    }
+    for (const auto& name_and_measurement : measurements) {
+      const std::string& name = name_and_measurement.first;
+      const double measurement = name_and_measurement.second;
+      // Counter was inserted with `kAvgIterations` flag by the constructor.
+ assert(counters.find(name) != counters.end()); + counters[name].value += measurement; + } + } +} + +void State::ResumeTiming() { + BM_CHECK(started_ && !finished_ && !skipped()); + timer_->StartTimer(); + if (perf_counters_measurement_) { + perf_counters_measurement_->Start(); + } +} + +void State::SkipWithMessage(const std::string& msg) { + skipped_ = internal::SkippedWithMessage; + { + MutexLock l(manager_->GetBenchmarkMutex()); + if (internal::NotSkipped == manager_->results.skipped_) { + manager_->results.skip_message_ = msg; + manager_->results.skipped_ = skipped_; + } + } + total_iterations_ = 0; + if (timer_->running()) timer_->StopTimer(); +} + +void State::SkipWithError(const std::string& msg) { + skipped_ = internal::SkippedWithError; + { + MutexLock l(manager_->GetBenchmarkMutex()); + if (internal::NotSkipped == manager_->results.skipped_) { + manager_->results.skip_message_ = msg; + manager_->results.skipped_ = skipped_; + } + } + total_iterations_ = 0; + if (timer_->running()) timer_->StopTimer(); +} + +void State::SetIterationTime(double seconds) { + timer_->SetIterationTime(seconds); +} + +void State::SetLabel(const std::string& label) { + MutexLock l(manager_->GetBenchmarkMutex()); + manager_->results.report_label_ = label; +} + +void State::StartKeepRunning() { + BM_CHECK(!started_ && !finished_); + started_ = true; + total_iterations_ = skipped() ? 0 : max_iterations; + if (BENCHMARK_BUILTIN_EXPECT(profiler_manager_ != nullptr, false)) + profiler_manager_->AfterSetupStart(); + manager_->StartStopBarrier(); + if (!skipped()) ResumeTiming(); +} + +void State::FinishKeepRunning() { + BM_CHECK(started_ && (!finished_ || skipped())); + if (!skipped()) { + PauseTiming(); + } + // Total iterations has now wrapped around past 0. Fix this. 
+ total_iterations_ = 0; + finished_ = true; + manager_->StartStopBarrier(); + if (BENCHMARK_BUILTIN_EXPECT(profiler_manager_ != nullptr, false)) + profiler_manager_->BeforeTeardownStop(); +} + +namespace internal { +namespace { + +// Flushes streams after invoking reporter methods that write to them. This +// ensures users get timely updates even when streams are not line-buffered. +void FlushStreams(BenchmarkReporter* reporter) { + if (!reporter) return; + std::flush(reporter->GetOutputStream()); + std::flush(reporter->GetErrorStream()); +} + +// Reports in both display and file reporters. +void Report(BenchmarkReporter* display_reporter, + BenchmarkReporter* file_reporter, const RunResults& run_results) { + auto report_one = [](BenchmarkReporter* reporter, bool aggregates_only, + const RunResults& results) { + assert(reporter); + // If there are no aggregates, do output non-aggregates. + aggregates_only &= !results.aggregates_only.empty(); + if (!aggregates_only) reporter->ReportRuns(results.non_aggregates); + if (!results.aggregates_only.empty()) + reporter->ReportRuns(results.aggregates_only); + }; + + report_one(display_reporter, run_results.display_report_aggregates_only, + run_results); + if (file_reporter) + report_one(file_reporter, run_results.file_report_aggregates_only, + run_results); + + FlushStreams(display_reporter); + FlushStreams(file_reporter); +} + +void RunBenchmarks(const std::vector& benchmarks, + BenchmarkReporter* display_reporter, + BenchmarkReporter* file_reporter) { + // Note the file_reporter can be null. + BM_CHECK(display_reporter != nullptr); + + // Determine the width of the name field using a minimum width of 10. 
+ bool might_have_aggregates = FLAGS_benchmark_repetitions > 1; + size_t name_field_width = 10; + size_t stat_field_width = 0; + for (const BenchmarkInstance& benchmark : benchmarks) { + name_field_width = + std::max(name_field_width, benchmark.name().str().size()); + might_have_aggregates |= benchmark.repetitions() > 1; + + for (const auto& Stat : benchmark.statistics()) + stat_field_width = std::max(stat_field_width, Stat.name_.size()); + } + if (might_have_aggregates) name_field_width += 1 + stat_field_width; + + // Print header here + BenchmarkReporter::Context context; + context.name_field_width = name_field_width; + + // Keep track of running times of all instances of each benchmark family. + std::map + per_family_reports; + + if (display_reporter->ReportContext(context) && + (!file_reporter || file_reporter->ReportContext(context))) { + FlushStreams(display_reporter); + FlushStreams(file_reporter); + + size_t num_repetitions_total = 0; + + // This perfcounters object needs to be created before the runners vector + // below so it outlasts their lifetime. + PerfCountersMeasurement perfcounters( + StrSplit(FLAGS_benchmark_perf_counters, ',')); + + // Vector of benchmarks to run + std::vector runners; + runners.reserve(benchmarks.size()); + + // Count the number of benchmarks with threads to warn the user in case + // performance counters are used. 
+ int benchmarks_with_threads = 0; + + // Loop through all benchmarks + for (const BenchmarkInstance& benchmark : benchmarks) { + BenchmarkReporter::PerFamilyRunReports* reports_for_family = nullptr; + if (benchmark.complexity() != oNone) + reports_for_family = &per_family_reports[benchmark.family_index()]; + benchmarks_with_threads += (benchmark.threads() > 1); + runners.emplace_back(benchmark, &perfcounters, reports_for_family); + int num_repeats_of_this_instance = runners.back().GetNumRepeats(); + num_repetitions_total += + static_cast(num_repeats_of_this_instance); + if (reports_for_family) + reports_for_family->num_runs_total += num_repeats_of_this_instance; + } + assert(runners.size() == benchmarks.size() && "Unexpected runner count."); + + // The use of performance counters with threads would be unintuitive for + // the average user so we need to warn them about this case + if ((benchmarks_with_threads > 0) && (perfcounters.num_counters() > 0)) { + GetErrorLogInstance() + << "***WARNING*** There are " << benchmarks_with_threads + << " benchmarks with threads and " << perfcounters.num_counters() + << " performance counters were requested. 
Beware counters will " + "reflect the combined usage across all " + "threads.\n"; + } + + std::vector repetition_indices; + repetition_indices.reserve(num_repetitions_total); + for (size_t runner_index = 0, num_runners = runners.size(); + runner_index != num_runners; ++runner_index) { + const internal::BenchmarkRunner& runner = runners[runner_index]; + std::fill_n(std::back_inserter(repetition_indices), + runner.GetNumRepeats(), runner_index); + } + assert(repetition_indices.size() == num_repetitions_total && + "Unexpected number of repetition indexes."); + + if (FLAGS_benchmark_enable_random_interleaving) { + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(repetition_indices.begin(), repetition_indices.end(), g); + } + + for (size_t repetition_index : repetition_indices) { + internal::BenchmarkRunner& runner = runners[repetition_index]; + runner.DoOneRepetition(); + if (runner.HasRepeatsRemaining()) continue; + // FIXME: report each repetition separately, not all of them in bulk. 
+ + display_reporter->ReportRunsConfig( + runner.GetMinTime(), runner.HasExplicitIters(), runner.GetIters()); + if (file_reporter) + file_reporter->ReportRunsConfig( + runner.GetMinTime(), runner.HasExplicitIters(), runner.GetIters()); + + RunResults run_results = runner.GetResults(); + + // Maybe calculate complexity report + if (const auto* reports_for_family = runner.GetReportsForFamily()) { + if (reports_for_family->num_runs_done == + reports_for_family->num_runs_total) { + auto additional_run_stats = ComputeBigO(reports_for_family->Runs); + run_results.aggregates_only.insert(run_results.aggregates_only.end(), + additional_run_stats.begin(), + additional_run_stats.end()); + per_family_reports.erase( + static_cast(reports_for_family->Runs.front().family_index)); + } + } + + Report(display_reporter, file_reporter, run_results); + } + } + display_reporter->Finalize(); + if (file_reporter) file_reporter->Finalize(); + FlushStreams(display_reporter); + FlushStreams(file_reporter); +} + +// Disable deprecated warnings temporarily because we need to reference +// CSVReporter but don't want to trigger -Werror=-Wdeprecated-declarations +BENCHMARK_DISABLE_DEPRECATED_WARNING + +std::unique_ptr CreateReporter( + std::string const& name, ConsoleReporter::OutputOptions output_opts) { + typedef std::unique_ptr PtrType; + if (name == "console") { + return PtrType(new ConsoleReporter(output_opts)); + } + if (name == "json") { + return PtrType(new JSONReporter()); + } + if (name == "csv") { + return PtrType(new CSVReporter()); + } + std::cerr << "Unexpected format: '" << name << "'\n"; + std::exit(1); +} + +BENCHMARK_RESTORE_DEPRECATED_WARNING + +} // end namespace + +bool IsZero(double n) { + return std::abs(n) < std::numeric_limits::epsilon(); +} + +ConsoleReporter::OutputOptions GetOutputOptions(bool force_no_color) { + int output_opts = ConsoleReporter::OO_Defaults; + auto is_benchmark_color = [force_no_color]() -> bool { + if (force_no_color) { + return false; + } + if 
(FLAGS_benchmark_color == "auto") { + return IsColorTerminal(); + } + return IsTruthyFlagValue(FLAGS_benchmark_color); + }; + if (is_benchmark_color()) { + output_opts |= ConsoleReporter::OO_Color; + } else { + output_opts &= ~ConsoleReporter::OO_Color; + } + if (FLAGS_benchmark_counters_tabular) { + output_opts |= ConsoleReporter::OO_Tabular; + } else { + output_opts &= ~ConsoleReporter::OO_Tabular; + } + return static_cast(output_opts); +} + +} // end namespace internal + +BenchmarkReporter* CreateDefaultDisplayReporter() { + static auto default_display_reporter = + internal::CreateReporter(FLAGS_benchmark_format, + internal::GetOutputOptions()) + .release(); + return default_display_reporter; +} + +size_t RunSpecifiedBenchmarks() { + return RunSpecifiedBenchmarks(nullptr, nullptr, FLAGS_benchmark_filter); +} + +size_t RunSpecifiedBenchmarks(std::string spec) { + return RunSpecifiedBenchmarks(nullptr, nullptr, std::move(spec)); +} + +size_t RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter) { + return RunSpecifiedBenchmarks(display_reporter, nullptr, + FLAGS_benchmark_filter); +} + +size_t RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter, + std::string spec) { + return RunSpecifiedBenchmarks(display_reporter, nullptr, std::move(spec)); +} + +size_t RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter, + BenchmarkReporter* file_reporter) { + return RunSpecifiedBenchmarks(display_reporter, file_reporter, + FLAGS_benchmark_filter); +} + +size_t RunSpecifiedBenchmarks(BenchmarkReporter* display_reporter, + BenchmarkReporter* file_reporter, + std::string spec) { + if (spec.empty() || spec == "all") + spec = "."; // Regexp that matches all benchmarks + + // Setup the reporters + std::ofstream output_file; + std::unique_ptr default_display_reporter; + std::unique_ptr default_file_reporter; + if (!display_reporter) { + default_display_reporter.reset(CreateDefaultDisplayReporter()); + display_reporter = default_display_reporter.get(); + } + auto& 
Out = display_reporter->GetOutputStream(); + auto& Err = display_reporter->GetErrorStream(); + + std::string const& fname = FLAGS_benchmark_out; + if (fname.empty() && file_reporter) { + Err << "A custom file reporter was provided but " + "--benchmark_out= was not specified." + << std::endl; + Out.flush(); + Err.flush(); + std::exit(1); + } + if (!fname.empty()) { + output_file.open(fname); + if (!output_file.is_open()) { + Err << "invalid file name: '" << fname << "'" << std::endl; + Out.flush(); + Err.flush(); + std::exit(1); + } + if (!file_reporter) { + default_file_reporter = internal::CreateReporter( + FLAGS_benchmark_out_format, FLAGS_benchmark_counters_tabular + ? ConsoleReporter::OO_Tabular + : ConsoleReporter::OO_None); + file_reporter = default_file_reporter.get(); + } + file_reporter->SetOutputStream(&output_file); + file_reporter->SetErrorStream(&output_file); + } + + std::vector benchmarks; + if (!FindBenchmarksInternal(spec, &benchmarks, &Err)) { + Out.flush(); + Err.flush(); + return 0; + } + + if (benchmarks.empty()) { + Err << "Failed to match any benchmarks against regex: " << spec << "\n"; + Out.flush(); + Err.flush(); + return 0; + } + + if (FLAGS_benchmark_list_tests) { + for (auto const& benchmark : benchmarks) + Out << benchmark.name().str() << "\n"; + } else { + internal::RunBenchmarks(benchmarks, display_reporter, file_reporter); + } + + Out.flush(); + Err.flush(); + return benchmarks.size(); +} + +namespace { +// stores the time unit benchmarks use by default +TimeUnit default_time_unit = kNanosecond; +} // namespace + +TimeUnit GetDefaultTimeUnit() { return default_time_unit; } + +void SetDefaultTimeUnit(TimeUnit unit) { default_time_unit = unit; } + +std::string GetBenchmarkFilter() { return FLAGS_benchmark_filter; } + +void SetBenchmarkFilter(std::string value) { + FLAGS_benchmark_filter = std::move(value); +} + +int32_t GetBenchmarkVerbosity() { return FLAGS_v; } + +void RegisterMemoryManager(MemoryManager* manager) { + 
internal::memory_manager = manager; +} + +void RegisterProfilerManager(ProfilerManager* manager) { + // Don't allow overwriting an existing manager. + if (manager != nullptr) { + BM_CHECK_EQ(internal::profiler_manager, nullptr); + } + internal::profiler_manager = manager; +} + +void AddCustomContext(const std::string& key, const std::string& value) { + if (internal::global_context == nullptr) { + internal::global_context = new std::map(); + } + if (!internal::global_context->emplace(key, value).second) { + std::cerr << "Failed to add custom context \"" << key << "\" as it already " + << "exists with value \"" << value << "\"\n"; + } +} + +namespace internal { + +void (*HelperPrintf)(); + +void PrintUsageAndExit() { + HelperPrintf(); + exit(0); +} + +void SetDefaultTimeUnitFromFlag(const std::string& time_unit_flag) { + if (time_unit_flag == "s") { + return SetDefaultTimeUnit(kSecond); + } + if (time_unit_flag == "ms") { + return SetDefaultTimeUnit(kMillisecond); + } + if (time_unit_flag == "us") { + return SetDefaultTimeUnit(kMicrosecond); + } + if (time_unit_flag == "ns") { + return SetDefaultTimeUnit(kNanosecond); + } + if (!time_unit_flag.empty()) { + PrintUsageAndExit(); + } +} + +void ParseCommandLineFlags(int* argc, char** argv) { + using namespace benchmark; + BenchmarkReporter::Context::executable_name = + (argc && *argc > 0) ? 
argv[0] : "unknown"; + for (int i = 1; argc && i < *argc; ++i) { + if (ParseBoolFlag(argv[i], "benchmark_list_tests", + &FLAGS_benchmark_list_tests) || + ParseStringFlag(argv[i], "benchmark_filter", &FLAGS_benchmark_filter) || + ParseStringFlag(argv[i], "benchmark_min_time", + &FLAGS_benchmark_min_time) || + ParseDoubleFlag(argv[i], "benchmark_min_warmup_time", + &FLAGS_benchmark_min_warmup_time) || + ParseInt32Flag(argv[i], "benchmark_repetitions", + &FLAGS_benchmark_repetitions) || + ParseBoolFlag(argv[i], "benchmark_dry_run", &FLAGS_benchmark_dry_run) || + ParseBoolFlag(argv[i], "benchmark_enable_random_interleaving", + &FLAGS_benchmark_enable_random_interleaving) || + ParseBoolFlag(argv[i], "benchmark_report_aggregates_only", + &FLAGS_benchmark_report_aggregates_only) || + ParseBoolFlag(argv[i], "benchmark_display_aggregates_only", + &FLAGS_benchmark_display_aggregates_only) || + ParseStringFlag(argv[i], "benchmark_format", &FLAGS_benchmark_format) || + ParseStringFlag(argv[i], "benchmark_out", &FLAGS_benchmark_out) || + ParseStringFlag(argv[i], "benchmark_out_format", + &FLAGS_benchmark_out_format) || + ParseStringFlag(argv[i], "benchmark_color", &FLAGS_benchmark_color) || + ParseBoolFlag(argv[i], "benchmark_counters_tabular", + &FLAGS_benchmark_counters_tabular) || + ParseStringFlag(argv[i], "benchmark_perf_counters", + &FLAGS_benchmark_perf_counters) || + ParseKeyValueFlag(argv[i], "benchmark_context", + &FLAGS_benchmark_context) || + ParseStringFlag(argv[i], "benchmark_time_unit", + &FLAGS_benchmark_time_unit) || + ParseInt32Flag(argv[i], "v", &FLAGS_v)) { + for (int j = i; j != *argc - 1; ++j) argv[j] = argv[j + 1]; + + --(*argc); + --i; + } else if (IsFlag(argv[i], "help")) { + PrintUsageAndExit(); + } + } + for (auto const* flag : + {&FLAGS_benchmark_format, &FLAGS_benchmark_out_format}) { + if (*flag != "console" && *flag != "json" && *flag != "csv") { + PrintUsageAndExit(); + } + } + SetDefaultTimeUnitFromFlag(FLAGS_benchmark_time_unit); + if 
(FLAGS_benchmark_color.empty()) { + PrintUsageAndExit(); + } + if (FLAGS_benchmark_dry_run) { + AddCustomContext("dry_run", "true"); + } + for (const auto& kv : FLAGS_benchmark_context) { + AddCustomContext(kv.first, kv.second); + } +} + +int InitializeStreams() { + static std::ios_base::Init init; + return 0; +} + +} // end namespace internal + +std::string GetBenchmarkVersion() { +#ifdef BENCHMARK_VERSION + return {BENCHMARK_VERSION}; +#else + return {""}; +#endif +} + +void PrintDefaultHelp() { + fprintf(stdout, + "benchmark" + " [--benchmark_list_tests={true|false}]\n" + " [--benchmark_filter=]\n" + " [--benchmark_min_time=`x` OR `s` ]\n" + " [--benchmark_min_warmup_time=]\n" + " [--benchmark_repetitions=]\n" + " [--benchmark_dry_run={true|false}]\n" + " [--benchmark_enable_random_interleaving={true|false}]\n" + " [--benchmark_report_aggregates_only={true|false}]\n" + " [--benchmark_display_aggregates_only={true|false}]\n" + " [--benchmark_format=]\n" + " [--benchmark_out=]\n" + " [--benchmark_out_format=]\n" + " [--benchmark_color={auto|true|false}]\n" + " [--benchmark_counters_tabular={true|false}]\n" +#if defined HAVE_LIBPFM + " [--benchmark_perf_counters=,...]\n" +#endif + " [--benchmark_context==,...]\n" + " [--benchmark_time_unit={ns|us|ms|s}]\n" + " [--v=]\n"); +} + +void Initialize(int* argc, char** argv, void (*HelperPrintf)()) { + internal::HelperPrintf = HelperPrintf; + internal::ParseCommandLineFlags(argc, argv); + internal::LogLevel() = FLAGS_v; +} + +void Shutdown() { delete internal::global_context; } + +bool ReportUnrecognizedArguments(int argc, char** argv) { + for (int i = 1; i < argc; ++i) { + fprintf(stderr, "%s: error: unrecognized command-line flag: %s\n", argv[0], + argv[i]); + } + return argc > 1; +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/benchmark_api_internal.cc b/third_party/benchmark/src/benchmark_api_internal.cc new file mode 100644 index 0000000..4b569d7 --- /dev/null +++ 
b/third_party/benchmark/src/benchmark_api_internal.cc @@ -0,0 +1,119 @@ +#include "benchmark_api_internal.h" + +#include + +#include "string_util.h" + +namespace benchmark { +namespace internal { + +BenchmarkInstance::BenchmarkInstance(Benchmark* benchmark, int family_idx, + int per_family_instance_idx, + const std::vector& args, + int thread_count) + : benchmark_(*benchmark), + family_index_(family_idx), + per_family_instance_index_(per_family_instance_idx), + aggregation_report_mode_(benchmark_.aggregation_report_mode_), + args_(args), + time_unit_(benchmark_.GetTimeUnit()), + measure_process_cpu_time_(benchmark_.measure_process_cpu_time_), + use_real_time_(benchmark_.use_real_time_), + use_manual_time_(benchmark_.use_manual_time_), + complexity_(benchmark_.complexity_), + complexity_lambda_(benchmark_.complexity_lambda_), + statistics_(benchmark_.statistics_), + repetitions_(benchmark_.repetitions_), + min_time_(benchmark_.min_time_), + min_warmup_time_(benchmark_.min_warmup_time_), + iterations_(benchmark_.iterations_), + threads_(thread_count) { + name_.function_name = benchmark_.name_; + + size_t arg_i = 0; + for (const auto& arg : args) { + if (!name_.args.empty()) { + name_.args += '/'; + } + + if (arg_i < benchmark->arg_names_.size()) { + const auto& arg_name = benchmark_.arg_names_[arg_i]; + if (!arg_name.empty()) { + name_.args += StrFormat("%s:", arg_name.c_str()); + } + } + + name_.args += StrFormat("%" PRId64, arg); + ++arg_i; + } + + if (!IsZero(benchmark->min_time_)) { + name_.min_time = StrFormat("min_time:%0.3f", benchmark_.min_time_); + } + + if (!IsZero(benchmark->min_warmup_time_)) { + name_.min_warmup_time = + StrFormat("min_warmup_time:%0.3f", benchmark_.min_warmup_time_); + } + + if (benchmark_.iterations_ != 0) { + name_.iterations = StrFormat( + "iterations:%lu", static_cast(benchmark_.iterations_)); + } + + if (benchmark_.repetitions_ != 0) { + name_.repetitions = StrFormat("repeats:%d", benchmark_.repetitions_); + } + + if 
(benchmark_.measure_process_cpu_time_) { + name_.time_type = "process_time"; + } + + if (benchmark_.use_manual_time_) { + if (!name_.time_type.empty()) { + name_.time_type += '/'; + } + name_.time_type += "manual_time"; + } else if (benchmark_.use_real_time_) { + if (!name_.time_type.empty()) { + name_.time_type += '/'; + } + name_.time_type += "real_time"; + } + + if (!benchmark_.thread_counts_.empty()) { + name_.threads = StrFormat("threads:%d", threads_); + } + + setup_ = benchmark_.setup_; + teardown_ = benchmark_.teardown_; +} + +State BenchmarkInstance::Run( + IterationCount iters, int thread_id, internal::ThreadTimer* timer, + internal::ThreadManager* manager, + internal::PerfCountersMeasurement* perf_counters_measurement, + ProfilerManager* profiler_manager) const { + State st(name_.function_name, iters, args_, thread_id, threads_, timer, + manager, perf_counters_measurement, profiler_manager); + benchmark_.Run(st); + return st; +} + +void BenchmarkInstance::Setup() const { + if (setup_) { + State st(name_.function_name, /*iters*/ 1, args_, /*thread_id*/ 0, threads_, + nullptr, nullptr, nullptr, nullptr); + setup_(st); + } +} + +void BenchmarkInstance::Teardown() const { + if (teardown_) { + State st(name_.function_name, /*iters*/ 1, args_, /*thread_id*/ 0, threads_, + nullptr, nullptr, nullptr, nullptr); + teardown_(st); + } +} +} // namespace internal +} // namespace benchmark diff --git a/third_party/benchmark/src/benchmark_api_internal.h b/third_party/benchmark/src/benchmark_api_internal.h new file mode 100644 index 0000000..659a714 --- /dev/null +++ b/third_party/benchmark/src/benchmark_api_internal.h @@ -0,0 +1,88 @@ +#ifndef BENCHMARK_API_INTERNAL_H +#define BENCHMARK_API_INTERNAL_H + +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "commandlineflags.h" + +namespace benchmark { +namespace internal { + +// Information kept per benchmark we may want to run +class BenchmarkInstance { + public: + 
BenchmarkInstance(Benchmark* benchmark, int family_index, + int per_family_instance_index, + const std::vector& args, int threads); + + const BenchmarkName& name() const { return name_; } + int family_index() const { return family_index_; } + int per_family_instance_index() const { return per_family_instance_index_; } + AggregationReportMode aggregation_report_mode() const { + return aggregation_report_mode_; + } + TimeUnit time_unit() const { return time_unit_; } + bool measure_process_cpu_time() const { return measure_process_cpu_time_; } + bool use_real_time() const { return use_real_time_; } + bool use_manual_time() const { return use_manual_time_; } + BigO complexity() const { return complexity_; } + BigOFunc* complexity_lambda() const { return complexity_lambda_; } + const std::vector& statistics() const { return statistics_; } + int repetitions() const { return repetitions_; } + double min_time() const { return min_time_; } + double min_warmup_time() const { return min_warmup_time_; } + IterationCount iterations() const { return iterations_; } + int threads() const { return threads_; } + void Setup() const; + void Teardown() const; + + State Run(IterationCount iters, int thread_id, internal::ThreadTimer* timer, + internal::ThreadManager* manager, + internal::PerfCountersMeasurement* perf_counters_measurement, + ProfilerManager* profiler_manager) const; + + private: + BenchmarkName name_; + Benchmark& benchmark_; + const int family_index_; + const int per_family_instance_index_; + AggregationReportMode aggregation_report_mode_; + const std::vector& args_; + TimeUnit time_unit_; + bool measure_process_cpu_time_; + bool use_real_time_; + bool use_manual_time_; + BigO complexity_; + BigOFunc* complexity_lambda_; + UserCounters counters_; + const std::vector& statistics_; + int repetitions_; + double min_time_; + double min_warmup_time_; + IterationCount iterations_; + int threads_; // Number of concurrent threads to us + + typedef void (*callback_function)(const 
benchmark::State&); + callback_function setup_ = nullptr; + callback_function teardown_ = nullptr; +}; + +bool FindBenchmarksInternal(const std::string& re, + std::vector* benchmarks, + std::ostream* Err); + +bool IsZero(double n); + +BENCHMARK_EXPORT +ConsoleReporter::OutputOptions GetOutputOptions(bool force_no_color = false); + +} // end namespace internal +} // end namespace benchmark + +#endif // BENCHMARK_API_INTERNAL_H diff --git a/third_party/benchmark/src/benchmark_main.cc b/third_party/benchmark/src/benchmark_main.cc new file mode 100644 index 0000000..cd61cd2 --- /dev/null +++ b/third_party/benchmark/src/benchmark_main.cc @@ -0,0 +1,18 @@ +// Copyright 2018 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark/benchmark.h" + +BENCHMARK_EXPORT int main(int, char**); +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/src/benchmark_name.cc b/third_party/benchmark/src/benchmark_name.cc new file mode 100644 index 0000000..01676bb --- /dev/null +++ b/third_party/benchmark/src/benchmark_name.cc @@ -0,0 +1,59 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace benchmark { + +namespace { + +// Compute the total size of a pack of std::strings +size_t size_impl() { return 0; } + +template +size_t size_impl(const Head& head, const Tail&... tail) { + return head.size() + size_impl(tail...); +} + +// Join a pack of std::strings using a delimiter +// TODO: use absl::StrJoin +void join_impl(std::string&, char) {} + +template +void join_impl(std::string& s, const char delimiter, const Head& head, + const Tail&... tail) { + if (!s.empty() && !head.empty()) { + s += delimiter; + } + + s += head; + + join_impl(s, delimiter, tail...); +} + +template +std::string join(char delimiter, const Ts&... ts) { + std::string s; + s.reserve(sizeof...(Ts) + size_impl(ts...)); + join_impl(s, delimiter, ts...); + return s; +} +} // namespace + +BENCHMARK_EXPORT +std::string BenchmarkName::str() const { + return join('/', function_name, args, min_time, min_warmup_time, iterations, + repetitions, time_type, threads); +} +} // namespace benchmark diff --git a/third_party/benchmark/src/benchmark_register.cc b/third_party/benchmark/src/benchmark_register.cc new file mode 100644 index 0000000..8ade048 --- /dev/null +++ b/third_party/benchmark/src/benchmark_register.cc @@ -0,0 +1,521 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark_register.h" + +#ifndef BENCHMARK_OS_WINDOWS +#if !defined(BENCHMARK_OS_FUCHSIA) && !defined(BENCHMARK_OS_QURT) +#include +#endif +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "benchmark_api_internal.h" +#include "check.h" +#include "commandlineflags.h" +#include "complexity.h" +#include "internal_macros.h" +#include "log.h" +#include "mutex.h" +#include "re.h" +#include "statistics.h" +#include "string_util.h" +#include "timers.h" + +namespace benchmark { + +namespace { +// For non-dense Range, intermediate values are powers of kRangeMultiplier. +static constexpr int kRangeMultiplier = 8; + +// The size of a benchmark family determines is the number of inputs to repeat +// the benchmark on. If this is "large" then warn the user during configuration. +static constexpr size_t kMaxFamilySize = 100; + +static constexpr char kDisabledPrefix[] = "DISABLED_"; +} // end namespace + +namespace internal { + +//=============================================================================// +// BenchmarkFamilies +//=============================================================================// + +// Class for managing registered benchmarks. Note that each registered +// benchmark identifies a family of related benchmarks to run. +class BenchmarkFamilies { + public: + static BenchmarkFamilies* GetInstance(); + + // Registers a benchmark family and returns the index assigned to it. 
+ size_t AddBenchmark(std::unique_ptr family); + + // Clear all registered benchmark families. + void ClearBenchmarks(); + + // Extract the list of benchmark instances that match the specified + // regular expression. + bool FindBenchmarks(std::string re, + std::vector* benchmarks, + std::ostream* Err); + + private: + BenchmarkFamilies() {} + + std::vector> families_; + Mutex mutex_; +}; + +BenchmarkFamilies* BenchmarkFamilies::GetInstance() { + static BenchmarkFamilies instance; + return &instance; +} + +size_t BenchmarkFamilies::AddBenchmark(std::unique_ptr family) { + MutexLock l(mutex_); + size_t index = families_.size(); + families_.push_back(std::move(family)); + return index; +} + +void BenchmarkFamilies::ClearBenchmarks() { + MutexLock l(mutex_); + families_.clear(); + families_.shrink_to_fit(); +} + +bool BenchmarkFamilies::FindBenchmarks( + std::string spec, std::vector* benchmarks, + std::ostream* ErrStream) { + BM_CHECK(ErrStream); + auto& Err = *ErrStream; + // Make regular expression out of command-line flag + std::string error_msg; + Regex re; + bool is_negative_filter = false; + if (spec[0] == '-') { + spec.replace(0, 1, ""); + is_negative_filter = true; + } + if (!re.Init(spec, &error_msg)) { + Err << "Could not compile benchmark re: " << error_msg << std::endl; + return false; + } + + // Special list of thread counts to use when none are specified + const std::vector one_thread = {1}; + + int next_family_index = 0; + + MutexLock l(mutex_); + for (std::unique_ptr& family : families_) { + int family_index = next_family_index; + int per_family_instance_index = 0; + + // Family was deleted or benchmark doesn't match + if (!family) continue; + + if (family->ArgsCnt() == -1) { + family->Args({}); + } + const std::vector* thread_counts = + (family->thread_counts_.empty() + ? 
&one_thread + : &static_cast&>(family->thread_counts_)); + const size_t family_size = family->args_.size() * thread_counts->size(); + // The benchmark will be run at least 'family_size' different inputs. + // If 'family_size' is very large warn the user. + if (family_size > kMaxFamilySize) { + Err << "The number of inputs is very large. " << family->name_ + << " will be repeated at least " << family_size << " times.\n"; + } + // reserve in the special case the regex ".", since we know the final + // family size. this doesn't take into account any disabled benchmarks + // so worst case we reserve more than we need. + if (spec == ".") benchmarks->reserve(benchmarks->size() + family_size); + + for (auto const& args : family->args_) { + for (int num_threads : *thread_counts) { + BenchmarkInstance instance(family.get(), family_index, + per_family_instance_index, args, + num_threads); + + const auto full_name = instance.name().str(); + if (full_name.rfind(kDisabledPrefix, 0) != 0 && + ((re.Match(full_name) && !is_negative_filter) || + (!re.Match(full_name) && is_negative_filter))) { + benchmarks->push_back(std::move(instance)); + + ++per_family_instance_index; + + // Only bump the next family index once we've estabilished that + // at least one instance of this family will be run. 
+ if (next_family_index == family_index) ++next_family_index; + } + } + } + } + return true; +} + +Benchmark* RegisterBenchmarkInternal(Benchmark* bench) { + std::unique_ptr bench_ptr(bench); + BenchmarkFamilies* families = BenchmarkFamilies::GetInstance(); + families->AddBenchmark(std::move(bench_ptr)); + return bench; +} + +// FIXME: This function is a hack so that benchmark.cc can access +// `BenchmarkFamilies` +bool FindBenchmarksInternal(const std::string& re, + std::vector* benchmarks, + std::ostream* Err) { + return BenchmarkFamilies::GetInstance()->FindBenchmarks(re, benchmarks, Err); +} + +//=============================================================================// +// Benchmark +//=============================================================================// + +Benchmark::Benchmark(const std::string& name) + : name_(name), + aggregation_report_mode_(ARM_Unspecified), + time_unit_(GetDefaultTimeUnit()), + use_default_time_unit_(true), + range_multiplier_(kRangeMultiplier), + min_time_(0), + min_warmup_time_(0), + iterations_(0), + repetitions_(0), + measure_process_cpu_time_(false), + use_real_time_(false), + use_manual_time_(false), + complexity_(oNone), + complexity_lambda_(nullptr), + setup_(nullptr), + teardown_(nullptr) { + ComputeStatistics("mean", StatisticsMean); + ComputeStatistics("median", StatisticsMedian); + ComputeStatistics("stddev", StatisticsStdDev); + ComputeStatistics("cv", StatisticsCV, kPercentage); +} + +Benchmark::~Benchmark() {} + +Benchmark* Benchmark::Name(const std::string& name) { + SetName(name); + return this; +} + +Benchmark* Benchmark::Arg(int64_t x) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == 1); + args_.push_back({x}); + return this; +} + +Benchmark* Benchmark::Unit(TimeUnit unit) { + time_unit_ = unit; + use_default_time_unit_ = false; + return this; +} + +Benchmark* Benchmark::Range(int64_t start, int64_t limit) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == 1); + std::vector arglist; + AddRange(&arglist, start, 
limit, range_multiplier_); + + for (int64_t i : arglist) { + args_.push_back({i}); + } + return this; +} + +Benchmark* Benchmark::Ranges( + const std::vector>& ranges) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == static_cast(ranges.size())); + std::vector> arglists(ranges.size()); + for (std::size_t i = 0; i < ranges.size(); i++) { + AddRange(&arglists[i], ranges[i].first, ranges[i].second, + range_multiplier_); + } + + ArgsProduct(arglists); + + return this; +} + +Benchmark* Benchmark::ArgsProduct( + const std::vector>& arglists) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == static_cast(arglists.size())); + + std::vector indices(arglists.size()); + const std::size_t total = std::accumulate( + std::begin(arglists), std::end(arglists), std::size_t{1}, + [](const std::size_t res, const std::vector& arglist) { + return res * arglist.size(); + }); + std::vector args; + args.reserve(arglists.size()); + for (std::size_t i = 0; i < total; i++) { + for (std::size_t arg = 0; arg < arglists.size(); arg++) { + args.push_back(arglists[arg][indices[arg]]); + } + args_.push_back(args); + args.clear(); + + std::size_t arg = 0; + do { + indices[arg] = (indices[arg] + 1) % arglists[arg].size(); + } while (indices[arg++] == 0 && arg < arglists.size()); + } + + return this; +} + +Benchmark* Benchmark::ArgName(const std::string& name) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == 1); + arg_names_ = {name}; + return this; +} + +Benchmark* Benchmark::ArgNames(const std::vector& names) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == static_cast(names.size())); + arg_names_ = names; + return this; +} + +Benchmark* Benchmark::DenseRange(int64_t start, int64_t limit, int step) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == 1); + BM_CHECK_LE(start, limit); + for (int64_t arg = start; arg <= limit; arg += step) { + args_.push_back({arg}); + } + return this; +} + +Benchmark* Benchmark::Args(const std::vector& args) { + BM_CHECK(ArgsCnt() == -1 || ArgsCnt() == static_cast(args.size())); + 
args_.push_back(args); + return this; +} + +Benchmark* Benchmark::Apply(void (*custom_arguments)(Benchmark* benchmark)) { + custom_arguments(this); + return this; +} + +Benchmark* Benchmark::Setup(void (*setup)(const benchmark::State&)) { + BM_CHECK(setup != nullptr); + setup_ = setup; + return this; +} + +Benchmark* Benchmark::Teardown(void (*teardown)(const benchmark::State&)) { + BM_CHECK(teardown != nullptr); + teardown_ = teardown; + return this; +} + +Benchmark* Benchmark::RangeMultiplier(int multiplier) { + BM_CHECK(multiplier > 1); + range_multiplier_ = multiplier; + return this; +} + +Benchmark* Benchmark::MinTime(double t) { + BM_CHECK(t > 0.0); + BM_CHECK(iterations_ == 0); + min_time_ = t; + return this; +} + +Benchmark* Benchmark::MinWarmUpTime(double t) { + BM_CHECK(t >= 0.0); + BM_CHECK(iterations_ == 0); + min_warmup_time_ = t; + return this; +} + +Benchmark* Benchmark::Iterations(IterationCount n) { + BM_CHECK(n > 0); + BM_CHECK(IsZero(min_time_)); + BM_CHECK(IsZero(min_warmup_time_)); + iterations_ = n; + return this; +} + +Benchmark* Benchmark::Repetitions(int n) { + BM_CHECK(n > 0); + repetitions_ = n; + return this; +} + +Benchmark* Benchmark::ReportAggregatesOnly(bool value) { + aggregation_report_mode_ = value ? ARM_ReportAggregatesOnly : ARM_Default; + return this; +} + +Benchmark* Benchmark::DisplayAggregatesOnly(bool value) { + // If we were called, the report mode is no longer 'unspecified', in any case. + aggregation_report_mode_ = static_cast( + aggregation_report_mode_ | ARM_Default); + + if (value) { + aggregation_report_mode_ = static_cast( + aggregation_report_mode_ | ARM_DisplayReportAggregatesOnly); + } else { + aggregation_report_mode_ = static_cast( + aggregation_report_mode_ & ~ARM_DisplayReportAggregatesOnly); + } + + return this; +} + +Benchmark* Benchmark::MeasureProcessCPUTime() { + // Can be used together with UseRealTime() / UseManualTime(). 
+ measure_process_cpu_time_ = true; + return this; +} + +Benchmark* Benchmark::UseRealTime() { + BM_CHECK(!use_manual_time_) + << "Cannot set UseRealTime and UseManualTime simultaneously."; + use_real_time_ = true; + return this; +} + +Benchmark* Benchmark::UseManualTime() { + BM_CHECK(!use_real_time_) + << "Cannot set UseRealTime and UseManualTime simultaneously."; + use_manual_time_ = true; + return this; +} + +Benchmark* Benchmark::Complexity(BigO complexity) { + complexity_ = complexity; + return this; +} + +Benchmark* Benchmark::Complexity(BigOFunc* complexity) { + complexity_lambda_ = complexity; + complexity_ = oLambda; + return this; +} + +Benchmark* Benchmark::ComputeStatistics(const std::string& name, + StatisticsFunc* statistics, + StatisticUnit unit) { + statistics_.emplace_back(name, statistics, unit); + return this; +} + +Benchmark* Benchmark::Threads(int t) { + BM_CHECK_GT(t, 0); + thread_counts_.push_back(t); + return this; +} + +Benchmark* Benchmark::ThreadRange(int min_threads, int max_threads) { + BM_CHECK_GT(min_threads, 0); + BM_CHECK_GE(max_threads, min_threads); + + AddRange(&thread_counts_, min_threads, max_threads, 2); + return this; +} + +Benchmark* Benchmark::DenseThreadRange(int min_threads, int max_threads, + int stride) { + BM_CHECK_GT(min_threads, 0); + BM_CHECK_GE(max_threads, min_threads); + BM_CHECK_GE(stride, 1); + + for (auto i = min_threads; i < max_threads; i += stride) { + thread_counts_.push_back(i); + } + thread_counts_.push_back(max_threads); + return this; +} + +Benchmark* Benchmark::ThreadPerCpu() { + thread_counts_.push_back(CPUInfo::Get().num_cpus); + return this; +} + +void Benchmark::SetName(const std::string& name) { name_ = name; } + +const char* Benchmark::GetName() const { return name_.c_str(); } + +int Benchmark::ArgsCnt() const { + if (args_.empty()) { + if (arg_names_.empty()) return -1; + return static_cast(arg_names_.size()); + } + return static_cast(args_.front().size()); +} + +const char* 
Benchmark::GetArgName(int arg) const { + BM_CHECK_GE(arg, 0); + size_t uarg = static_cast(arg); + BM_CHECK_LT(uarg, arg_names_.size()); + return arg_names_[uarg].c_str(); +} + +TimeUnit Benchmark::GetTimeUnit() const { + return use_default_time_unit_ ? GetDefaultTimeUnit() : time_unit_; +} + +//=============================================================================// +// FunctionBenchmark +//=============================================================================// + +void FunctionBenchmark::Run(State& st) { func_(st); } + +} // end namespace internal + +void ClearRegisteredBenchmarks() { + internal::BenchmarkFamilies::GetInstance()->ClearBenchmarks(); +} + +std::vector CreateRange(int64_t lo, int64_t hi, int multi) { + std::vector args; + internal::AddRange(&args, lo, hi, multi); + return args; +} + +std::vector CreateDenseRange(int64_t start, int64_t limit, int step) { + BM_CHECK_LE(start, limit); + std::vector args; + for (int64_t arg = start; arg <= limit; arg += step) { + args.push_back(arg); + } + return args; +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/benchmark_register.h b/third_party/benchmark/src/benchmark_register.h new file mode 100644 index 0000000..be50265 --- /dev/null +++ b/third_party/benchmark/src/benchmark_register.h @@ -0,0 +1,109 @@ +#ifndef BENCHMARK_REGISTER_H +#define BENCHMARK_REGISTER_H + +#include +#include +#include + +#include "check.h" + +namespace benchmark { +namespace internal { + +// Append the powers of 'mult' in the closed interval [lo, hi]. +// Returns iterator to the start of the inserted range. 
+template +typename std::vector::iterator AddPowers(std::vector* dst, T lo, T hi, + int mult) { + BM_CHECK_GE(lo, 0); + BM_CHECK_GE(hi, lo); + BM_CHECK_GE(mult, 2); + + const size_t start_offset = dst->size(); + + static const T kmax = std::numeric_limits::max(); + + // Space out the values in multiples of "mult" + for (T i = static_cast(1); i <= hi; i = static_cast(i * mult)) { + if (i >= lo) { + dst->push_back(i); + } + // Break the loop here since multiplying by + // 'mult' would move outside of the range of T + if (i > kmax / mult) break; + } + + return dst->begin() + static_cast(start_offset); +} + +template +void AddNegatedPowers(std::vector* dst, T lo, T hi, int mult) { + // We negate lo and hi so we require that they cannot be equal to 'min'. + BM_CHECK_GT(lo, std::numeric_limits::min()); + BM_CHECK_GT(hi, std::numeric_limits::min()); + BM_CHECK_GE(hi, lo); + BM_CHECK_LE(hi, 0); + + // Add positive powers, then negate and reverse. + // Casts necessary since small integers get promoted + // to 'int' when negating. + const auto lo_complement = static_cast(-lo); + const auto hi_complement = static_cast(-hi); + + const auto it = AddPowers(dst, hi_complement, lo_complement, mult); + + std::for_each(it, dst->end(), [](T& t) { t = static_cast(t * -1); }); + std::reverse(it, dst->end()); +} + +template +void AddRange(std::vector* dst, T lo, T hi, int mult) { + static_assert(std::is_integral::value && std::is_signed::value, + "Args type must be a signed integer"); + + BM_CHECK_GE(hi, lo); + BM_CHECK_GE(mult, 2); + + // Add "lo" + dst->push_back(lo); + + // Handle lo == hi as a special case, so we then know + // lo < hi and so it is safe to add 1 to lo and subtract 1 + // from hi without falling outside of the range of T. + if (lo == hi) return; + + // Ensure that lo_inner <= hi_inner below. + if (lo + 1 == hi) { + dst->push_back(hi); + return; + } + + // Add all powers of 'mult' in the range [lo+1, hi-1] (inclusive). 
+ const auto lo_inner = static_cast(lo + 1); + const auto hi_inner = static_cast(hi - 1); + + // Insert negative values + if (lo_inner < 0) { + AddNegatedPowers(dst, lo_inner, std::min(hi_inner, T{-1}), mult); + } + + // Treat 0 as a special case (see discussion on #762). + if (lo < 0 && hi >= 0) { + dst->push_back(0); + } + + // Insert positive values + if (hi_inner > 0) { + AddPowers(dst, std::max(lo_inner, T{1}), hi_inner, mult); + } + + // Add "hi" (if different from last value). + if (hi != dst->back()) { + dst->push_back(hi); + } +} + +} // namespace internal +} // namespace benchmark + +#endif // BENCHMARK_REGISTER_H diff --git a/third_party/benchmark/src/benchmark_runner.cc b/third_party/benchmark/src/benchmark_runner.cc new file mode 100644 index 0000000..463f69f --- /dev/null +++ b/third_party/benchmark/src/benchmark_runner.cc @@ -0,0 +1,539 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "benchmark_runner.h" + +#include "benchmark/benchmark.h" +#include "benchmark_api_internal.h" +#include "internal_macros.h" + +#ifndef BENCHMARK_OS_WINDOWS +#if !defined(BENCHMARK_OS_FUCHSIA) && !defined(BENCHMARK_OS_QURT) +#include +#endif +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "check.h" +#include "colorprint.h" +#include "commandlineflags.h" +#include "complexity.h" +#include "counter.h" +#include "internal_macros.h" +#include "log.h" +#include "mutex.h" +#include "perf_counters.h" +#include "re.h" +#include "statistics.h" +#include "string_util.h" +#include "thread_manager.h" +#include "thread_timer.h" + +namespace benchmark { + +BM_DECLARE_bool(benchmark_dry_run); +BM_DECLARE_string(benchmark_min_time); +BM_DECLARE_double(benchmark_min_warmup_time); +BM_DECLARE_int32(benchmark_repetitions); +BM_DECLARE_bool(benchmark_report_aggregates_only); +BM_DECLARE_bool(benchmark_display_aggregates_only); +BM_DECLARE_string(benchmark_perf_counters); + +namespace internal { + +MemoryManager* memory_manager = nullptr; + +ProfilerManager* profiler_manager = nullptr; + +namespace { + +static constexpr IterationCount kMaxIterations = 1000000000000; +const double kDefaultMinTime = + std::strtod(::benchmark::kDefaultMinTimeStr, /*p_end*/ nullptr); + +BenchmarkReporter::Run CreateRunReport( + const benchmark::internal::BenchmarkInstance& b, + const internal::ThreadManager::Result& results, + IterationCount memory_iterations, + const MemoryManager::Result* memory_result, double seconds, + int64_t repetition_index, int64_t repeats) { + // Create report about this benchmark run. 
+ BenchmarkReporter::Run report; + + report.run_name = b.name(); + report.family_index = b.family_index(); + report.per_family_instance_index = b.per_family_instance_index(); + report.skipped = results.skipped_; + report.skip_message = results.skip_message_; + report.report_label = results.report_label_; + // This is the total iterations across all threads. + report.iterations = results.iterations; + report.time_unit = b.time_unit(); + report.threads = b.threads(); + report.repetition_index = repetition_index; + report.repetitions = repeats; + + if (!report.skipped) { + if (b.use_manual_time()) { + report.real_accumulated_time = results.manual_time_used; + } else { + report.real_accumulated_time = results.real_time_used; + } + report.use_real_time_for_initial_big_o = b.use_manual_time(); + report.cpu_accumulated_time = results.cpu_time_used; + report.complexity_n = results.complexity_n; + report.complexity = b.complexity(); + report.complexity_lambda = b.complexity_lambda(); + report.statistics = &b.statistics(); + report.counters = results.counters; + + if (memory_iterations > 0) { + assert(memory_result != nullptr); + report.memory_result = memory_result; + report.allocs_per_iter = + memory_iterations ? static_cast(memory_result->num_allocs) / + static_cast(memory_iterations) + : 0; + } + + internal::Finish(&report.counters, results.iterations, seconds, + b.threads()); + } + return report; +} + +// Execute one thread of benchmark b for the specified number of iterations. +// Adds the stats collected for the thread into manager->results. +void RunInThread(const BenchmarkInstance* b, IterationCount iters, + int thread_id, ThreadManager* manager, + PerfCountersMeasurement* perf_counters_measurement, + ProfilerManager* profiler_manager_) { + internal::ThreadTimer timer( + b->measure_process_cpu_time() + ? 
internal::ThreadTimer::CreateProcessCpuTime() + : internal::ThreadTimer::Create()); + + State st = b->Run(iters, thread_id, &timer, manager, + perf_counters_measurement, profiler_manager_); + BM_CHECK(st.skipped() || st.iterations() >= st.max_iterations) + << "Benchmark returned before State::KeepRunning() returned false!"; + { + MutexLock l(manager->GetBenchmarkMutex()); + internal::ThreadManager::Result& results = manager->results; + results.iterations += st.iterations(); + results.cpu_time_used += timer.cpu_time_used(); + results.real_time_used += timer.real_time_used(); + results.manual_time_used += timer.manual_time_used(); + results.complexity_n += st.complexity_length_n(); + internal::Increment(&results.counters, st.counters); + } + manager->NotifyThreadComplete(); +} + +double ComputeMinTime(const benchmark::internal::BenchmarkInstance& b, + const BenchTimeType& iters_or_time) { + if (!IsZero(b.min_time())) return b.min_time(); + // If the flag was used to specify number of iters, then return the default + // min_time. + if (iters_or_time.tag == BenchTimeType::ITERS) return kDefaultMinTime; + + return iters_or_time.time; +} + +IterationCount ComputeIters(const benchmark::internal::BenchmarkInstance& b, + const BenchTimeType& iters_or_time) { + if (b.iterations() != 0) return b.iterations(); + + // We've already concluded that this flag is currently used to pass + // iters but do a check here again anyway. + BM_CHECK(iters_or_time.tag == BenchTimeType::ITERS); + return iters_or_time.iters; +} + +} // end namespace + +BenchTimeType ParseBenchMinTime(const std::string& value) { + BenchTimeType ret; + + if (value.empty()) { + ret.tag = BenchTimeType::TIME; + ret.time = 0.0; + return ret; + } + + if (value.back() == 'x') { + char* p_end; + // Reset errno before it's changed by strtol. + errno = 0; + IterationCount num_iters = std::strtol(value.c_str(), &p_end, 10); + + // After a valid parse, p_end should have been set to + // point to the 'x' suffix. 
+ BM_CHECK(errno == 0 && p_end != nullptr && *p_end == 'x') + << "Malformed iters value passed to --benchmark_min_time: `" << value + << "`. Expected --benchmark_min_time=x."; + + ret.tag = BenchTimeType::ITERS; + ret.iters = num_iters; + return ret; + } + + bool has_suffix = value.back() == 's'; + if (!has_suffix) { + BM_VLOG(0) << "Value passed to --benchmark_min_time should have a suffix. " + "Eg., `30s` for 30-seconds."; + } + + char* p_end; + // Reset errno before it's changed by strtod. + errno = 0; + double min_time = std::strtod(value.c_str(), &p_end); + + // After a successful parse, p_end should point to the suffix 's', + // or the end of the string if the suffix was omitted. + BM_CHECK(errno == 0 && p_end != nullptr && + ((has_suffix && *p_end == 's') || *p_end == '\0')) + << "Malformed seconds value passed to --benchmark_min_time: `" << value + << "`. Expected --benchmark_min_time=x."; + + ret.tag = BenchTimeType::TIME; + ret.time = min_time; + + return ret; +} + +BenchmarkRunner::BenchmarkRunner( + const benchmark::internal::BenchmarkInstance& b_, + PerfCountersMeasurement* pcm_, + BenchmarkReporter::PerFamilyRunReports* reports_for_family_) + : b(b_), + reports_for_family(reports_for_family_), + parsed_benchtime_flag(ParseBenchMinTime(FLAGS_benchmark_min_time)), + min_time(FLAGS_benchmark_dry_run + ? 0 + : ComputeMinTime(b_, parsed_benchtime_flag)), + min_warmup_time( + FLAGS_benchmark_dry_run + ? 0 + : ((!IsZero(b.min_time()) && b.min_warmup_time() > 0.0) + ? b.min_warmup_time() + : FLAGS_benchmark_min_warmup_time)), + warmup_done(FLAGS_benchmark_dry_run ? true : !(min_warmup_time > 0.0)), + repeats(FLAGS_benchmark_dry_run + ? 1 + : (b.repetitions() != 0 ? b.repetitions() + : FLAGS_benchmark_repetitions)), + has_explicit_iteration_count(b.iterations() != 0 || + parsed_benchtime_flag.tag == + BenchTimeType::ITERS), + pool(static_cast(b.threads() - 1)), + iters(FLAGS_benchmark_dry_run + ? 1 + : (has_explicit_iteration_count + ? 
ComputeIters(b_, parsed_benchtime_flag) + : 1)), + perf_counters_measurement_ptr(pcm_) { + run_results.display_report_aggregates_only = + (FLAGS_benchmark_report_aggregates_only || + FLAGS_benchmark_display_aggregates_only); + run_results.file_report_aggregates_only = + FLAGS_benchmark_report_aggregates_only; + if (b.aggregation_report_mode() != internal::ARM_Unspecified) { + run_results.display_report_aggregates_only = + (b.aggregation_report_mode() & + internal::ARM_DisplayReportAggregatesOnly); + run_results.file_report_aggregates_only = + (b.aggregation_report_mode() & internal::ARM_FileReportAggregatesOnly); + BM_CHECK(FLAGS_benchmark_perf_counters.empty() || + (perf_counters_measurement_ptr->num_counters() == 0)) + << "Perf counters were requested but could not be set up."; + } +} + +BenchmarkRunner::IterationResults BenchmarkRunner::DoNIterations() { + BM_VLOG(2) << "Running " << b.name().str() << " for " << iters << "\n"; + + std::unique_ptr manager; + manager.reset(new internal::ThreadManager(b.threads())); + + // Run all but one thread in separate threads + for (std::size_t ti = 0; ti < pool.size(); ++ti) { + pool[ti] = std::thread(&RunInThread, &b, iters, static_cast(ti + 1), + manager.get(), perf_counters_measurement_ptr, + /*profiler_manager=*/nullptr); + } + // And run one thread here directly. + // (If we were asked to run just one thread, we don't create new threads.) + // Yes, we need to do this here *after* we start the separate threads. + RunInThread(&b, iters, 0, manager.get(), perf_counters_measurement_ptr, + /*profiler_manager=*/nullptr); + + // The main thread has finished. Now let's wait for the other threads. + manager->WaitForAllThreads(); + for (std::thread& thread : pool) thread.join(); + + IterationResults i; + // Acquire the measurements/counters from the manager, UNDER THE LOCK! + { + MutexLock l(manager->GetBenchmarkMutex()); + i.results = manager->results; + } + + // And get rid of the manager. 
+  manager.reset();
+
+  BM_VLOG(2) << "Ran in " << i.results.cpu_time_used << "/"
+             << i.results.real_time_used << "\n";
+
+  // By using KeepRunningBatch a benchmark can iterate more times than
+  // requested, so take the iteration count from i.results.
+  i.iters = i.results.iterations / b.threads();
+
+  // Base decisions off of real time if requested by this benchmark.
+  // (CPU time is the default; manual time takes precedence over real time.)
+  i.seconds = i.results.cpu_time_used;
+  if (b.use_manual_time()) {
+    i.seconds = i.results.manual_time_used;
+  } else if (b.use_real_time()) {
+    i.seconds = i.results.real_time_used;
+  }
+
+  return i;
+}
+
+IterationCount BenchmarkRunner::PredictNumItersNeeded(
+    const IterationResults& i) const {
+  // See by how much the iteration count should be increased.
+  // Note: Avoid division by zero with max(seconds, 1ns).
+  double multiplier = GetMinTimeToApply() * 1.4 / std::max(i.seconds, 1e-9);
+  // If our last run was at least 10% of FLAGS_benchmark_min_time then we
+  // use the multiplier directly.
+  // Otherwise we use at most 10 times expansion.
+  // NOTE: When the last run was at least 10% of the min time the max
+  // expansion should be 14x.
+  const bool is_significant = (i.seconds / GetMinTimeToApply()) > 0.1;
+  multiplier = is_significant ? multiplier : 10.0;
+
+  // So what seems to be the sufficiently-large iteration count? Round up.
+  const IterationCount max_next_iters = static_cast(
+      std::llround(std::max(multiplier * static_cast(i.iters),
+                            static_cast(i.iters) + 1.0)));
+  // But we do have *some* limits, though.
+  const IterationCount next_iters = std::min(max_next_iters, kMaxIterations);
+
+  BM_VLOG(3) << "Next iters: " << next_iters << ", " << multiplier << "\n";
+  return next_iters;  // already rounded up via llround above.
+}
+
+bool BenchmarkRunner::ShouldReportIterationResults(
+    const IterationResults& i) const {
+  // Determine if this run should be reported:
+  // Either it has run for a sufficient amount of time
+  // or because an error was reported.
+  return i.results.skipped_ || FLAGS_benchmark_dry_run ||
+         i.iters >= kMaxIterations ||  // Too many iterations already.
+         i.seconds >=
+             GetMinTimeToApply() ||  // The elapsed time is large enough.
+         // CPU time is specified but the elapsed real time greatly exceeds
+         // the minimum time.
+         // Note that user-provided timers are exempt from this test.
+         ((i.results.real_time_used >= 5 * GetMinTimeToApply()) &&
+          !b.use_manual_time());
+}
+
+double BenchmarkRunner::GetMinTimeToApply() const {
+  // In order to re-use functionality to run and measure benchmarks for running
+  // a warmup phase of the benchmark, we need a way of telling whether to apply
+  // min_time or min_warmup_time. This function will figure out if we are in the
+  // warmup phase and therefore need to apply min_warmup_time or if we are
+  // already in the benchmarking phase and min_time needs to be applied.
+  return warmup_done ? min_time : min_warmup_time;
+}
+
+void BenchmarkRunner::FinishWarmUp(const IterationCount& i) {
+  warmup_done = true;
+  iters = i;
+}
+
+void BenchmarkRunner::RunWarmUp() {
+  // Use the same mechanisms for warming up the benchmark as used for actually
+  // running and measuring the benchmark.
+  IterationResults i_warmup;
+  // Don't use the iterations determined in the warmup phase for the actual
+  // measured benchmark phase. While this may be a good starting point for the
+  // benchmark and it would therefore get rid of the need to figure out how many
+  // iterations are needed if min_time is set again, this may also be a
+  // completely wrong guess since the warmup loops might be considerably slower
+  // (e.g. because of caching effects).
+ const IterationCount i_backup = iters; + + for (;;) { + b.Setup(); + i_warmup = DoNIterations(); + b.Teardown(); + + const bool finish = ShouldReportIterationResults(i_warmup); + + if (finish) { + FinishWarmUp(i_backup); + break; + } + + // Although we are running "only" a warmup phase where running enough + // iterations at once without measuring time isn't as important as it is for + // the benchmarking phase, we still do it the same way as otherwise it is + // very confusing for the user to know how to choose a proper value for + // min_warmup_time if a different approach on running it is used. + iters = PredictNumItersNeeded(i_warmup); + assert(iters > i_warmup.iters && + "if we did more iterations than we want to do the next time, " + "then we should have accepted the current iteration run."); + } +} + +MemoryManager::Result* BenchmarkRunner::RunMemoryManager( + IterationCount memory_iterations) { + // TODO(vyng): Consider making BenchmarkReporter::Run::memory_result an + // optional so we don't have to own the Result here. + // Can't do it now due to cxx03. + memory_results.push_back(MemoryManager::Result()); + MemoryManager::Result* memory_result = &memory_results.back(); + memory_manager->Start(); + std::unique_ptr manager; + manager.reset(new internal::ThreadManager(1)); + b.Setup(); + RunInThread(&b, memory_iterations, 0, manager.get(), + perf_counters_measurement_ptr, + /*profiler_manager=*/nullptr); + manager->WaitForAllThreads(); + manager.reset(); + b.Teardown(); + memory_manager->Stop(*memory_result); + return memory_result; +} + +void BenchmarkRunner::RunProfilerManager() { + // TODO: Provide a way to specify the number of iterations. 
+ IterationCount profile_iterations = 1; + std::unique_ptr manager; + manager.reset(new internal::ThreadManager(1)); + b.Setup(); + RunInThread(&b, profile_iterations, 0, manager.get(), + /*perf_counters_measurement_ptr=*/nullptr, + /*profiler_manager=*/profiler_manager); + manager->WaitForAllThreads(); + manager.reset(); + b.Teardown(); +} + +void BenchmarkRunner::DoOneRepetition() { + assert(HasRepeatsRemaining() && "Already done all repetitions?"); + + const bool is_the_first_repetition = num_repetitions_done == 0; + + // In case a warmup phase is requested by the benchmark, run it now. + // After running the warmup phase the BenchmarkRunner should be in a state as + // this warmup never happened except the fact that warmup_done is set. Every + // other manipulation of the BenchmarkRunner instance would be a bug! Please + // fix it. + if (!warmup_done) RunWarmUp(); + + IterationResults i; + // We *may* be gradually increasing the length (iteration count) + // of the benchmark until we decide the results are significant. + // And once we do, we report those last results and exit. + // Please do note that the if there are repetitions, the iteration count + // is *only* calculated for the *first* repetition, and other repetitions + // simply use that precomputed iteration count. + for (;;) { + b.Setup(); + i = DoNIterations(); + b.Teardown(); + + // Do we consider the results to be significant? + // If we are doing repetitions, and the first repetition was already done, + // it has calculated the correct iteration time, so we have run that very + // iteration count just now. No need to calculate anything. Just report. + // Else, the normal rules apply. + const bool results_are_significant = !is_the_first_repetition || + has_explicit_iteration_count || + ShouldReportIterationResults(i); + + if (results_are_significant) break; // Good, let's report them! + + // Nope, bad iteration. 
Let's re-estimate the hopefully-sufficient + // iteration count, and run the benchmark again... + + iters = PredictNumItersNeeded(i); + assert(iters > i.iters && + "if we did more iterations than we want to do the next time, " + "then we should have accepted the current iteration run."); + } + + // Produce memory measurements if requested. + MemoryManager::Result* memory_result = nullptr; + IterationCount memory_iterations = 0; + if (memory_manager != nullptr) { + // Only run a few iterations to reduce the impact of one-time + // allocations in benchmarks that are not properly managed. + memory_iterations = std::min(16, iters); + memory_result = RunMemoryManager(memory_iterations); + } + + if (profiler_manager != nullptr) { + RunProfilerManager(); + } + + // Ok, now actually report. + BenchmarkReporter::Run report = + CreateRunReport(b, i.results, memory_iterations, memory_result, i.seconds, + num_repetitions_done, repeats); + + if (reports_for_family) { + ++reports_for_family->num_runs_done; + if (!report.skipped) reports_for_family->Runs.push_back(report); + } + + run_results.non_aggregates.push_back(report); + + ++num_repetitions_done; +} + +RunResults&& BenchmarkRunner::GetResults() { + assert(!HasRepeatsRemaining() && "Did not run all repetitions yet?"); + + // Calculate additional statistics over the repetitions of this instance. + run_results.aggregates_only = ComputeStats(run_results.non_aggregates); + + return std::move(run_results); +} + +} // end namespace internal + +} // end namespace benchmark diff --git a/third_party/benchmark/src/benchmark_runner.h b/third_party/benchmark/src/benchmark_runner.h new file mode 100644 index 0000000..6e5ceb3 --- /dev/null +++ b/third_party/benchmark/src/benchmark_runner.h @@ -0,0 +1,129 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef BENCHMARK_RUNNER_H_ +#define BENCHMARK_RUNNER_H_ + +#include +#include + +#include "benchmark_api_internal.h" +#include "internal_macros.h" +#include "perf_counters.h" +#include "thread_manager.h" + +namespace benchmark { + +namespace internal { + +extern MemoryManager* memory_manager; +extern ProfilerManager* profiler_manager; + +struct RunResults { + std::vector non_aggregates; + std::vector aggregates_only; + + bool display_report_aggregates_only = false; + bool file_report_aggregates_only = false; +}; + +struct BENCHMARK_EXPORT BenchTimeType { + enum { ITERS, TIME } tag; + union { + IterationCount iters; + double time; + }; +}; + +BENCHMARK_EXPORT +BenchTimeType ParseBenchMinTime(const std::string& value); + +class BenchmarkRunner { + public: + BenchmarkRunner(const benchmark::internal::BenchmarkInstance& b_, + benchmark::internal::PerfCountersMeasurement* pmc_, + BenchmarkReporter::PerFamilyRunReports* reports_for_family); + + int GetNumRepeats() const { return repeats; } + + bool HasRepeatsRemaining() const { + return GetNumRepeats() != num_repetitions_done; + } + + void DoOneRepetition(); + + RunResults&& GetResults(); + + BenchmarkReporter::PerFamilyRunReports* GetReportsForFamily() const { + return reports_for_family; + } + + double GetMinTime() const { return min_time; } + + bool HasExplicitIters() const { return has_explicit_iteration_count; } + + IterationCount GetIters() const { return iters; } + + private: + RunResults run_results; + + const benchmark::internal::BenchmarkInstance& b; + 
BenchmarkReporter::PerFamilyRunReports* reports_for_family; + + BenchTimeType parsed_benchtime_flag; + const double min_time; + const double min_warmup_time; + bool warmup_done; + const int repeats; + const bool has_explicit_iteration_count; + + int num_repetitions_done = 0; + + std::vector pool; + + std::vector memory_results; + + IterationCount iters; // preserved between repetitions! + // So only the first repetition has to find/calculate it, + // the other repetitions will just use that precomputed iteration count. + + PerfCountersMeasurement* const perf_counters_measurement_ptr = nullptr; + + struct IterationResults { + internal::ThreadManager::Result results; + IterationCount iters; + double seconds; + }; + IterationResults DoNIterations(); + + MemoryManager::Result* RunMemoryManager(IterationCount memory_iterations); + + void RunProfilerManager(); + + IterationCount PredictNumItersNeeded(const IterationResults& i) const; + + bool ShouldReportIterationResults(const IterationResults& i) const; + + double GetMinTimeToApply() const; + + void FinishWarmUp(const IterationCount& i); + + void RunWarmUp(); +}; + +} // namespace internal + +} // end namespace benchmark + +#endif // BENCHMARK_RUNNER_H_ diff --git a/third_party/benchmark/src/check.cc b/third_party/benchmark/src/check.cc new file mode 100644 index 0000000..5f7526e --- /dev/null +++ b/third_party/benchmark/src/check.cc @@ -0,0 +1,11 @@ +#include "check.h" + +namespace benchmark { +namespace internal { + +static AbortHandlerT* handler = &std::abort; + +BENCHMARK_EXPORT AbortHandlerT*& GetAbortHandler() { return handler; } + +} // namespace internal +} // namespace benchmark diff --git a/third_party/benchmark/src/check.h b/third_party/benchmark/src/check.h new file mode 100644 index 0000000..c1cd5e8 --- /dev/null +++ b/third_party/benchmark/src/check.h @@ -0,0 +1,106 @@ +#ifndef CHECK_H_ +#define CHECK_H_ + +#include +#include +#include + +#include "benchmark/export.h" +#include "internal_macros.h" 
+#include "log.h" + +#if defined(__GNUC__) || defined(__clang__) +#define BENCHMARK_NOEXCEPT noexcept +#define BENCHMARK_NOEXCEPT_OP(x) noexcept(x) +#elif defined(_MSC_VER) && !defined(__clang__) +#if _MSC_VER >= 1900 +#define BENCHMARK_NOEXCEPT noexcept +#define BENCHMARK_NOEXCEPT_OP(x) noexcept(x) +#else +#define BENCHMARK_NOEXCEPT +#define BENCHMARK_NOEXCEPT_OP(x) +#endif +#define __func__ __FUNCTION__ +#else +#define BENCHMARK_NOEXCEPT +#define BENCHMARK_NOEXCEPT_OP(x) +#endif + +namespace benchmark { +namespace internal { + +typedef void(AbortHandlerT)(); + +BENCHMARK_EXPORT +AbortHandlerT*& GetAbortHandler(); + +BENCHMARK_NORETURN inline void CallAbortHandler() { + GetAbortHandler()(); + std::abort(); // fallback to enforce noreturn +} + +// CheckHandler is the class constructed by failing BM_CHECK macros. +// CheckHandler will log information about the failures and abort when it is +// destructed. +class CheckHandler { + public: + CheckHandler(const char* check, const char* file, const char* func, int line) + : log_(GetErrorLogInstance()) { + log_ << file << ":" << line << ": " << func << ": Check `" << check + << "' failed. "; + } + + LogType& GetLog() { return log_; } + +#if defined(COMPILER_MSVC) +#pragma warning(push) +#pragma warning(disable : 4722) +#endif + BENCHMARK_NORETURN ~CheckHandler() BENCHMARK_NOEXCEPT_OP(false) { + log_ << std::endl; + CallAbortHandler(); + } +#if defined(COMPILER_MSVC) +#pragma warning(pop) +#endif + + CheckHandler& operator=(const CheckHandler&) = delete; + CheckHandler(const CheckHandler&) = delete; + CheckHandler() = delete; + + private: + LogType& log_; +}; + +} // end namespace internal +} // end namespace benchmark + +// The BM_CHECK macro returns a std::ostream object that can have extra +// information written to it. +#ifndef NDEBUG +#define BM_CHECK(b) \ + (b ? 
::benchmark::internal::GetNullLogInstance() \ + : ::benchmark::internal::CheckHandler(#b, __FILE__, __func__, __LINE__) \ + .GetLog()) +#else +#define BM_CHECK(b) ::benchmark::internal::GetNullLogInstance() +#endif + +// clang-format off +// preserve whitespacing between operators for alignment +#define BM_CHECK_EQ(a, b) BM_CHECK((a) == (b)) +#define BM_CHECK_NE(a, b) BM_CHECK((a) != (b)) +#define BM_CHECK_GE(a, b) BM_CHECK((a) >= (b)) +#define BM_CHECK_LE(a, b) BM_CHECK((a) <= (b)) +#define BM_CHECK_GT(a, b) BM_CHECK((a) > (b)) +#define BM_CHECK_LT(a, b) BM_CHECK((a) < (b)) + +#define BM_CHECK_FLOAT_EQ(a, b, eps) BM_CHECK(std::fabs((a) - (b)) < (eps)) +#define BM_CHECK_FLOAT_NE(a, b, eps) BM_CHECK(std::fabs((a) - (b)) >= (eps)) +#define BM_CHECK_FLOAT_GE(a, b, eps) BM_CHECK((a) - (b) > -(eps)) +#define BM_CHECK_FLOAT_LE(a, b, eps) BM_CHECK((b) - (a) > -(eps)) +#define BM_CHECK_FLOAT_GT(a, b, eps) BM_CHECK((a) - (b) > (eps)) +#define BM_CHECK_FLOAT_LT(a, b, eps) BM_CHECK((b) - (a) > (eps)) +//clang-format on + +#endif // CHECK_H_ diff --git a/third_party/benchmark/src/colorprint.cc b/third_party/benchmark/src/colorprint.cc new file mode 100644 index 0000000..fd1971a --- /dev/null +++ b/third_party/benchmark/src/colorprint.cc @@ -0,0 +1,206 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "colorprint.h" + +#include +#include +#include +#include +#include +#include + +#include "check.h" +#include "internal_macros.h" + +#ifdef BENCHMARK_OS_WINDOWS +#include +#include +#else +#include +#endif // BENCHMARK_OS_WINDOWS + +namespace benchmark { +namespace { +#ifdef BENCHMARK_OS_WINDOWS +typedef WORD PlatformColorCode; +#else +typedef const char* PlatformColorCode; +#endif + +PlatformColorCode GetPlatformColorCode(LogColor color) { +#ifdef BENCHMARK_OS_WINDOWS + switch (color) { + case COLOR_RED: + return FOREGROUND_RED; + case COLOR_GREEN: + return FOREGROUND_GREEN; + case COLOR_YELLOW: + return FOREGROUND_RED | FOREGROUND_GREEN; + case COLOR_BLUE: + return FOREGROUND_BLUE; + case COLOR_MAGENTA: + return FOREGROUND_BLUE | FOREGROUND_RED; + case COLOR_CYAN: + return FOREGROUND_BLUE | FOREGROUND_GREEN; + case COLOR_WHITE: // fall through to default + default: + return 0; + } +#else + switch (color) { + case COLOR_RED: + return "1"; + case COLOR_GREEN: + return "2"; + case COLOR_YELLOW: + return "3"; + case COLOR_BLUE: + return "4"; + case COLOR_MAGENTA: + return "5"; + case COLOR_CYAN: + return "6"; + case COLOR_WHITE: + return "7"; + default: + return nullptr; + }; +#endif +} + +} // end namespace + +std::string FormatString(const char* msg, va_list args) { + // we might need a second shot at this, so pre-emptivly make a copy + va_list args_cp; + va_copy(args_cp, args); + + std::size_t size = 256; + char local_buff[256]; + auto ret = vsnprintf(local_buff, size, msg, args_cp); + + va_end(args_cp); + + // currently there is no error handling for failure, so this is hack. + BM_CHECK(ret >= 0); + + if (ret == 0) { // handle empty expansion + return {}; + } + if (static_cast(ret) < size) { + return local_buff; + } + // we did not provide a long enough buffer on our first attempt. 
+ size = static_cast(ret) + 1; // + 1 for the null byte + std::unique_ptr buff(new char[size]); + ret = vsnprintf(buff.get(), size, msg, args); + BM_CHECK(ret > 0 && (static_cast(ret)) < size); + return buff.get(); +} + +std::string FormatString(const char* msg, ...) { + va_list args; + va_start(args, msg); + auto tmp = FormatString(msg, args); + va_end(args); + return tmp; +} + +void ColorPrintf(std::ostream& out, LogColor color, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + ColorPrintf(out, color, fmt, args); + va_end(args); +} + +void ColorPrintf(std::ostream& out, LogColor color, const char* fmt, + va_list args) { +#ifdef BENCHMARK_OS_WINDOWS + ((void)out); // suppress unused warning + + const HANDLE stdout_handle = GetStdHandle(STD_OUTPUT_HANDLE); + + // Gets the current text color. + CONSOLE_SCREEN_BUFFER_INFO buffer_info; + GetConsoleScreenBufferInfo(stdout_handle, &buffer_info); + const WORD original_color_attrs = buffer_info.wAttributes; + + // We need to flush the stream buffers into the console before each + // SetConsoleTextAttribute call lest it affect the text that is already + // printed but has not yet reached the console. + out.flush(); + + const WORD original_background_attrs = + original_color_attrs & (BACKGROUND_RED | BACKGROUND_GREEN | + BACKGROUND_BLUE | BACKGROUND_INTENSITY); + + SetConsoleTextAttribute(stdout_handle, GetPlatformColorCode(color) | + FOREGROUND_INTENSITY | + original_background_attrs); + out << FormatString(fmt, args); + + out.flush(); + // Restores the text and background color. + SetConsoleTextAttribute(stdout_handle, original_color_attrs); +#else + const char* color_code = GetPlatformColorCode(color); + if (color_code) out << FormatString("\033[0;3%sm", color_code); + out << FormatString(fmt, args) << "\033[m"; +#endif +} + +bool IsColorTerminal() { +#if BENCHMARK_OS_WINDOWS + // On Windows the TERM variable is usually not set, but the + // console there does support colors. 
+ return 0 != _isatty(_fileno(stdout)); +#else + // On non-Windows platforms, we rely on the TERM variable. This list of + // supported TERM values is copied from Google Test: + // . + const char* const SUPPORTED_TERM_VALUES[] = { + "xterm", + "xterm-color", + "xterm-256color", + "screen", + "screen-256color", + "tmux", + "tmux-256color", + "rxvt-unicode", + "rxvt-unicode-256color", + "linux", + "cygwin", + "xterm-kitty", + "alacritty", + "foot", + "foot-extra", + "wezterm", + }; + + const char* const term = getenv("TERM"); + + bool term_supports_color = false; + for (const char* candidate : SUPPORTED_TERM_VALUES) { + if (term && 0 == strcmp(term, candidate)) { + term_supports_color = true; + break; + } + } + + return 0 != isatty(fileno(stdout)) && term_supports_color; +#endif // BENCHMARK_OS_WINDOWS +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/colorprint.h b/third_party/benchmark/src/colorprint.h new file mode 100644 index 0000000..9f6fab9 --- /dev/null +++ b/third_party/benchmark/src/colorprint.h @@ -0,0 +1,33 @@ +#ifndef BENCHMARK_COLORPRINT_H_ +#define BENCHMARK_COLORPRINT_H_ + +#include +#include +#include + +namespace benchmark { +enum LogColor { + COLOR_DEFAULT, + COLOR_RED, + COLOR_GREEN, + COLOR_YELLOW, + COLOR_BLUE, + COLOR_MAGENTA, + COLOR_CYAN, + COLOR_WHITE +}; + +std::string FormatString(const char* msg, va_list args); +std::string FormatString(const char* msg, ...); + +void ColorPrintf(std::ostream& out, LogColor color, const char* fmt, + va_list args); +void ColorPrintf(std::ostream& out, LogColor color, const char* fmt, ...); + +// Returns true if stdout appears to be a terminal that supports colored +// output, false otherwise. 
+bool IsColorTerminal(); + +} // end namespace benchmark + +#endif // BENCHMARK_COLORPRINT_H_ diff --git a/third_party/benchmark/src/commandlineflags.cc b/third_party/benchmark/src/commandlineflags.cc new file mode 100644 index 0000000..dcb4149 --- /dev/null +++ b/third_party/benchmark/src/commandlineflags.cc @@ -0,0 +1,298 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "commandlineflags.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../src/string_util.h" + +namespace benchmark { +namespace { + +// Parses 'str' for a 32-bit signed integer. If successful, writes +// the result to *value and returns true; otherwise leaves *value +// unchanged and returns false. +bool ParseInt32(const std::string& src_text, const char* str, int32_t* value) { + // Parses the environment variable as a decimal integer. + char* end = nullptr; + const long long_value = strtol(str, &end, 10); // NOLINT + + // Has strtol() consumed all characters in the string? + if (*end != '\0') { + // No - an invalid character was encountered. + std::cerr << src_text << " is expected to be a 32-bit integer, " + << "but actually has value \"" << str << "\".\n"; + return false; + } + + // Is the parsed value in the range of an Int32? 
+ const int32_t result = static_cast(long_value); + if (long_value == std::numeric_limits::max() || + long_value == std::numeric_limits::min() || + // The parsed value overflows as a long. (strtol() returns + // LONG_MAX or LONG_MIN when the input overflows.) + result != long_value + // The parsed value overflows as an Int32. + ) { + std::cerr << src_text << " is expected to be a 32-bit integer, " + << "but actually has value \"" << str << "\", " + << "which overflows.\n"; + return false; + } + + *value = result; + return true; +} + +// Parses 'str' for a double. If successful, writes the result to *value and +// returns true; otherwise leaves *value unchanged and returns false. +bool ParseDouble(const std::string& src_text, const char* str, double* value) { + // Parses the environment variable as a decimal integer. + char* end = nullptr; + const double double_value = strtod(str, &end); // NOLINT + + // Has strtol() consumed all characters in the string? + if (*end != '\0') { + // No - an invalid character was encountered. + std::cerr << src_text << " is expected to be a double, " + << "but actually has value \"" << str << "\".\n"; + return false; + } + + *value = double_value; + return true; +} + +// Parses 'str' into KV pairs. If successful, writes the result to *value and +// returns true; otherwise leaves *value unchanged and returns false. 
+bool ParseKvPairs(const std::string& src_text, const char* str, + std::map* value) { + std::map kvs; + for (const auto& kvpair : StrSplit(str, ',')) { + const auto kv = StrSplit(kvpair, '='); + if (kv.size() != 2) { + std::cerr << src_text << " is expected to be a comma-separated list of " + << "= strings, but actually has value \"" << str + << "\".\n"; + return false; + } + if (!kvs.emplace(kv[0], kv[1]).second) { + std::cerr << src_text << " is expected to contain unique keys but key \"" + << kv[0] << "\" was repeated.\n"; + return false; + } + } + + *value = kvs; + return true; +} + +// Returns the name of the environment variable corresponding to the +// given flag. For example, FlagToEnvVar("foo") will return +// "BENCHMARK_FOO" in the open-source version. +static std::string FlagToEnvVar(const char* flag) { + const std::string flag_str(flag); + + std::string env_var; + for (size_t i = 0; i != flag_str.length(); ++i) + env_var += static_cast(::toupper(flag_str.c_str()[i])); + + return env_var; +} + +} // namespace + +BENCHMARK_EXPORT +bool BoolFromEnv(const char* flag, bool default_val) { + const std::string env_var = FlagToEnvVar(flag); + const char* const value_str = getenv(env_var.c_str()); + return value_str == nullptr ? 
default_val : IsTruthyFlagValue(value_str); +} + +BENCHMARK_EXPORT +int32_t Int32FromEnv(const char* flag, int32_t default_val) { + const std::string env_var = FlagToEnvVar(flag); + const char* const value_str = getenv(env_var.c_str()); + int32_t value = default_val; + if (value_str == nullptr || + !ParseInt32(std::string("Environment variable ") + env_var, value_str, + &value)) { + return default_val; + } + return value; +} + +BENCHMARK_EXPORT +double DoubleFromEnv(const char* flag, double default_val) { + const std::string env_var = FlagToEnvVar(flag); + const char* const value_str = getenv(env_var.c_str()); + double value = default_val; + if (value_str == nullptr || + !ParseDouble(std::string("Environment variable ") + env_var, value_str, + &value)) { + return default_val; + } + return value; +} + +BENCHMARK_EXPORT +const char* StringFromEnv(const char* flag, const char* default_val) { + const std::string env_var = FlagToEnvVar(flag); + const char* const value = getenv(env_var.c_str()); + return value == nullptr ? default_val : value; +} + +BENCHMARK_EXPORT +std::map KvPairsFromEnv( + const char* flag, std::map default_val) { + const std::string env_var = FlagToEnvVar(flag); + const char* const value_str = getenv(env_var.c_str()); + + if (value_str == nullptr) return default_val; + + std::map value; + if (!ParseKvPairs("Environment variable " + env_var, value_str, &value)) { + return default_val; + } + return value; +} + +// Parses a string as a command line flag. The string should have +// the format "--flag=value". When def_optional is true, the "=value" +// part can be omitted. +// +// Returns the value of the flag, or nullptr if the parsing failed. +const char* ParseFlagValue(const char* str, const char* flag, + bool def_optional) { + // str and flag must not be nullptr. + if (str == nullptr || flag == nullptr) return nullptr; + + // The flag must start with "--". 
+ const std::string flag_str = std::string("--") + std::string(flag); + const size_t flag_len = flag_str.length(); + if (strncmp(str, flag_str.c_str(), flag_len) != 0) return nullptr; + + // Skips the flag name. + const char* flag_end = str + flag_len; + + // When def_optional is true, it's OK to not have a "=value" part. + if (def_optional && (flag_end[0] == '\0')) return flag_end; + + // If def_optional is true and there are more characters after the + // flag name, or if def_optional is false, there must be a '=' after + // the flag name. + if (flag_end[0] != '=') return nullptr; + + // Returns the string after "=". + return flag_end + 1; +} + +BENCHMARK_EXPORT +bool ParseBoolFlag(const char* str, const char* flag, bool* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, true); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + // Converts the string value to a bool. + *value = IsTruthyFlagValue(value_str); + return true; +} + +BENCHMARK_EXPORT +bool ParseInt32Flag(const char* str, const char* flag, int32_t* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + // Sets *value to the value of the flag. + return ParseInt32(std::string("The value of flag --") + flag, value_str, + value); +} + +BENCHMARK_EXPORT +bool ParseDoubleFlag(const char* str, const char* flag, double* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + // Sets *value to the value of the flag. 
+ return ParseDouble(std::string("The value of flag --") + flag, value_str, + value); +} + +BENCHMARK_EXPORT +bool ParseStringFlag(const char* str, const char* flag, std::string* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + *value = value_str; + return true; +} + +BENCHMARK_EXPORT +bool ParseKeyValueFlag(const char* str, const char* flag, + std::map* value) { + const char* const value_str = ParseFlagValue(str, flag, false); + + if (value_str == nullptr) return false; + + for (const auto& kvpair : StrSplit(value_str, ',')) { + const auto kv = StrSplit(kvpair, '='); + if (kv.size() != 2) return false; + value->emplace(kv[0], kv[1]); + } + + return true; +} + +BENCHMARK_EXPORT +bool IsFlag(const char* str, const char* flag) { + return (ParseFlagValue(str, flag, true) != nullptr); +} + +BENCHMARK_EXPORT +bool IsTruthyFlagValue(const std::string& value) { + if (value.size() == 1) { + char v = value[0]; + return isalnum(v) && + !(v == '0' || v == 'f' || v == 'F' || v == 'n' || v == 'N'); + } + if (!value.empty()) { + std::string value_lower(value); + std::transform(value_lower.begin(), value_lower.end(), value_lower.begin(), + [](char c) { return static_cast(::tolower(c)); }); + return !(value_lower == "false" || value_lower == "no" || + value_lower == "off"); + } + return true; +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/commandlineflags.h b/third_party/benchmark/src/commandlineflags.h new file mode 100644 index 0000000..7882628 --- /dev/null +++ b/third_party/benchmark/src/commandlineflags.h @@ -0,0 +1,133 @@ +#ifndef BENCHMARK_COMMANDLINEFLAGS_H_ +#define BENCHMARK_COMMANDLINEFLAGS_H_ + +#include +#include +#include + +#include "benchmark/export.h" + +// Macro for referencing flags. +#define FLAG(name) FLAGS_##name + +// Macros for declaring flags. 
+#define BM_DECLARE_bool(name) BENCHMARK_EXPORT extern bool FLAG(name) +#define BM_DECLARE_int32(name) BENCHMARK_EXPORT extern int32_t FLAG(name) +#define BM_DECLARE_double(name) BENCHMARK_EXPORT extern double FLAG(name) +#define BM_DECLARE_string(name) BENCHMARK_EXPORT extern std::string FLAG(name) +#define BM_DECLARE_kvpairs(name) \ + BENCHMARK_EXPORT extern std::map FLAG(name) + +// Macros for defining flags. +#define BM_DEFINE_bool(name, default_val) \ + BENCHMARK_EXPORT bool FLAG(name) = benchmark::BoolFromEnv(#name, default_val) +#define BM_DEFINE_int32(name, default_val) \ + BENCHMARK_EXPORT int32_t FLAG(name) = \ + benchmark::Int32FromEnv(#name, default_val) +#define BM_DEFINE_double(name, default_val) \ + BENCHMARK_EXPORT double FLAG(name) = \ + benchmark::DoubleFromEnv(#name, default_val) +#define BM_DEFINE_string(name, default_val) \ + BENCHMARK_EXPORT std::string FLAG(name) = \ + benchmark::StringFromEnv(#name, default_val) +#define BM_DEFINE_kvpairs(name, default_val) \ + BENCHMARK_EXPORT std::map FLAG(name) = \ + benchmark::KvPairsFromEnv(#name, default_val) + +namespace benchmark { + +// Parses a bool from the environment variable corresponding to the given flag. +// +// If the variable exists, returns IsTruthyFlagValue() value; if not, +// returns the given default value. +BENCHMARK_EXPORT +bool BoolFromEnv(const char* flag, bool default_val); + +// Parses an Int32 from the environment variable corresponding to the given +// flag. +// +// If the variable exists, returns ParseInt32() value; if not, returns +// the given default value. +BENCHMARK_EXPORT +int32_t Int32FromEnv(const char* flag, int32_t default_val); + +// Parses an Double from the environment variable corresponding to the given +// flag. +// +// If the variable exists, returns ParseDouble(); if not, returns +// the given default value. 
+BENCHMARK_EXPORT +double DoubleFromEnv(const char* flag, double default_val); + +// Parses a string from the environment variable corresponding to the given +// flag. +// +// If variable exists, returns its value; if not, returns +// the given default value. +BENCHMARK_EXPORT +const char* StringFromEnv(const char* flag, const char* default_val); + +// Parses a set of kvpairs from the environment variable corresponding to the +// given flag. +// +// If variable exists, returns its value; if not, returns +// the given default value. +BENCHMARK_EXPORT +std::map KvPairsFromEnv( + const char* flag, std::map default_val); + +// Parses a string for a bool flag, in the form of either +// "--flag=value" or "--flag". +// +// In the former case, the value is taken as true if it passes IsTruthyValue(). +// +// In the latter case, the value is taken as true. +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +BENCHMARK_EXPORT +bool ParseBoolFlag(const char* str, const char* flag, bool* value); + +// Parses a string for an Int32 flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +BENCHMARK_EXPORT +bool ParseInt32Flag(const char* str, const char* flag, int32_t* value); + +// Parses a string for a Double flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +BENCHMARK_EXPORT +bool ParseDoubleFlag(const char* str, const char* flag, double* value); + +// Parses a string for a string flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. 
+BENCHMARK_EXPORT +bool ParseStringFlag(const char* str, const char* flag, std::string* value); + +// Parses a string for a kvpairs flag in the form "--flag=key=value,key=value" +// +// On success, stores the value of the flag in *value and returns true. On +// failure returns false, though *value may have been mutated. +BENCHMARK_EXPORT +bool ParseKeyValueFlag(const char* str, const char* flag, + std::map* value); + +// Returns true if the string matches the flag. +BENCHMARK_EXPORT +bool IsFlag(const char* str, const char* flag); + +// Returns true unless value starts with one of: '0', 'f', 'F', 'n' or 'N', or +// some non-alphanumeric character. Also returns false if the value matches +// one of 'no', 'false', 'off' (case-insensitive). As a special case, also +// returns true if value is the empty string. +BENCHMARK_EXPORT +bool IsTruthyFlagValue(const std::string& value); + +} // end namespace benchmark + +#endif // BENCHMARK_COMMANDLINEFLAGS_H_ diff --git a/third_party/benchmark/src/complexity.cc b/third_party/benchmark/src/complexity.cc new file mode 100644 index 0000000..63acd50 --- /dev/null +++ b/third_party/benchmark/src/complexity.cc @@ -0,0 +1,255 @@ +// Copyright 2016 Ismael Jimenez Martinez. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Source project : https://github.com/ismaelJimenez/cpp.leastsq +// Adapted to be used with google benchmark + +#include "complexity.h" + +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" + +namespace benchmark { + +// Internal function to calculate the different scalability forms +BigOFunc* FittingCurve(BigO complexity) { + switch (complexity) { + case oN: + return [](IterationCount n) -> double { return static_cast(n); }; + case oNSquared: + return [](IterationCount n) -> double { return std::pow(n, 2); }; + case oNCubed: + return [](IterationCount n) -> double { return std::pow(n, 3); }; + case oLogN: + return [](IterationCount n) -> double { + return std::log2(static_cast(n)); + }; + case oNLogN: + return [](IterationCount n) -> double { + return static_cast(n) * std::log2(static_cast(n)); + }; + case o1: + default: + return [](IterationCount) { return 1.0; }; + } +} + +// Function to return an string for the calculated complexity +std::string GetBigOString(BigO complexity) { + switch (complexity) { + case oN: + return "N"; + case oNSquared: + return "N^2"; + case oNCubed: + return "N^3"; + case oLogN: + return "lgN"; + case oNLogN: + return "NlgN"; + case o1: + return "(1)"; + default: + return "f(N)"; + } +} + +// Find the coefficient for the high-order term in the running time, by +// minimizing the sum of squares of relative error, for the fitting curve +// given by the lambda expression. +// - n : Vector containing the size of the benchmark tests. +// - time : Vector containing the times for the benchmark tests. +// - fitting_curve : lambda expression (e.g. [](ComplexityN n) {return n; };). 
+ +// For a deeper explanation on the algorithm logic, please refer to +// https://en.wikipedia.org/wiki/Least_squares#Least_squares,_regression_analysis_and_statistics + +LeastSq MinimalLeastSq(const std::vector& n, + const std::vector& time, + BigOFunc* fitting_curve) { + double sigma_gn_squared = 0.0; + double sigma_time = 0.0; + double sigma_time_gn = 0.0; + + // Calculate least square fitting parameter + for (size_t i = 0; i < n.size(); ++i) { + double gn_i = fitting_curve(n[i]); + sigma_gn_squared += gn_i * gn_i; + sigma_time += time[i]; + sigma_time_gn += time[i] * gn_i; + } + + LeastSq result; + result.complexity = oLambda; + + // Calculate complexity. + result.coef = sigma_time_gn / sigma_gn_squared; + + // Calculate RMS + double rms = 0.0; + for (size_t i = 0; i < n.size(); ++i) { + double fit = result.coef * fitting_curve(n[i]); + rms += std::pow((time[i] - fit), 2); + } + + // Normalized RMS by the mean of the observed values + double mean = sigma_time / static_cast(n.size()); + result.rms = std::sqrt(rms / static_cast(n.size())) / mean; + + return result; +} + +// Find the coefficient for the high-order term in the running time, by +// minimizing the sum of squares of relative error. +// - n : Vector containing the size of the benchmark tests. +// - time : Vector containing the times for the benchmark tests. +// - complexity : If different than oAuto, the fitting curve will stick to +// this one. If it is oAuto, it will be calculated the best +// fitting curve. 
+LeastSq MinimalLeastSq(const std::vector& n, + const std::vector& time, const BigO complexity) { + BM_CHECK_EQ(n.size(), time.size()); + BM_CHECK_GE(n.size(), 2); // Do not compute fitting curve is less than two + // benchmark runs are given + BM_CHECK_NE(complexity, oNone); + + LeastSq best_fit; + + if (complexity == oAuto) { + std::vector fit_curves = {oLogN, oN, oNLogN, oNSquared, oNCubed}; + + // Take o1 as default best fitting curve + best_fit = MinimalLeastSq(n, time, FittingCurve(o1)); + best_fit.complexity = o1; + + // Compute all possible fitting curves and stick to the best one + for (const auto& fit : fit_curves) { + LeastSq current_fit = MinimalLeastSq(n, time, FittingCurve(fit)); + if (current_fit.rms < best_fit.rms) { + best_fit = current_fit; + best_fit.complexity = fit; + } + } + } else { + best_fit = MinimalLeastSq(n, time, FittingCurve(complexity)); + best_fit.complexity = complexity; + } + + return best_fit; +} + +std::vector ComputeBigO( + const std::vector& reports) { + typedef BenchmarkReporter::Run Run; + std::vector results; + + if (reports.size() < 2) return results; + + // Accumulators. + std::vector n; + std::vector real_time; + std::vector cpu_time; + + // Populate the accumulators. 
+ for (const Run& run : reports) { + BM_CHECK_GT(run.complexity_n, 0) + << "Did you forget to call SetComplexityN?"; + n.push_back(run.complexity_n); + real_time.push_back(run.real_accumulated_time / + static_cast(run.iterations)); + cpu_time.push_back(run.cpu_accumulated_time / + static_cast(run.iterations)); + } + + LeastSq result_cpu; + LeastSq result_real; + + if (reports[0].complexity == oLambda) { + result_cpu = MinimalLeastSq(n, cpu_time, reports[0].complexity_lambda); + result_real = MinimalLeastSq(n, real_time, reports[0].complexity_lambda); + } else { + const BigO* InitialBigO = &reports[0].complexity; + const bool use_real_time_for_initial_big_o = + reports[0].use_real_time_for_initial_big_o; + if (use_real_time_for_initial_big_o) { + result_real = MinimalLeastSq(n, real_time, *InitialBigO); + InitialBigO = &result_real.complexity; + // The Big-O complexity for CPU time must have the same Big-O function! + } + result_cpu = MinimalLeastSq(n, cpu_time, *InitialBigO); + InitialBigO = &result_cpu.complexity; + if (!use_real_time_for_initial_big_o) { + result_real = MinimalLeastSq(n, real_time, *InitialBigO); + } + } + + // Drop the 'args' when reporting complexity. + auto run_name = reports[0].run_name; + run_name.args.clear(); + + // Get the data from the accumulator to BenchmarkReporter::Run's. 
+ Run big_o; + big_o.run_name = run_name; + big_o.family_index = reports[0].family_index; + big_o.per_family_instance_index = reports[0].per_family_instance_index; + big_o.run_type = BenchmarkReporter::Run::RT_Aggregate; + big_o.repetitions = reports[0].repetitions; + big_o.repetition_index = Run::no_repetition_index; + big_o.threads = reports[0].threads; + big_o.aggregate_name = "BigO"; + big_o.aggregate_unit = StatisticUnit::kTime; + big_o.report_label = reports[0].report_label; + big_o.iterations = 0; + big_o.real_accumulated_time = result_real.coef; + big_o.cpu_accumulated_time = result_cpu.coef; + big_o.report_big_o = true; + big_o.complexity = result_cpu.complexity; + + // All the time results are reported after being multiplied by the + // time unit multiplier. But since RMS is a relative quantity it + // should not be multiplied at all. So, here, we _divide_ it by the + // multiplier so that when it is multiplied later the result is the + // correct one. + double multiplier = GetTimeUnitMultiplier(reports[0].time_unit); + + // Only add label to mean/stddev if it is same for all runs + Run rms; + rms.run_name = run_name; + rms.family_index = reports[0].family_index; + rms.per_family_instance_index = reports[0].per_family_instance_index; + rms.run_type = BenchmarkReporter::Run::RT_Aggregate; + rms.aggregate_name = "RMS"; + rms.aggregate_unit = StatisticUnit::kPercentage; + rms.report_label = big_o.report_label; + rms.iterations = 0; + rms.repetition_index = Run::no_repetition_index; + rms.repetitions = reports[0].repetitions; + rms.threads = reports[0].threads; + rms.real_accumulated_time = result_real.rms / multiplier; + rms.cpu_accumulated_time = result_cpu.rms / multiplier; + rms.report_rms = true; + rms.complexity = result_cpu.complexity; + // don't forget to keep the time unit, or we won't be able to + // recover the correct value. 
+ rms.time_unit = reports[0].time_unit; + + results.push_back(big_o); + results.push_back(rms); + return results; +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/complexity.h b/third_party/benchmark/src/complexity.h new file mode 100644 index 0000000..0a0679b --- /dev/null +++ b/third_party/benchmark/src/complexity.h @@ -0,0 +1,55 @@ +// Copyright 2016 Ismael Jimenez Martinez. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Source project : https://github.com/ismaelJimenez/cpp.leastsq +// Adapted to be used with google benchmark + +#ifndef COMPLEXITY_H_ +#define COMPLEXITY_H_ + +#include +#include + +#include "benchmark/benchmark.h" + +namespace benchmark { + +// Return a vector containing the bigO and RMS information for the specified +// list of reports. If 'reports.size() < 2' an empty vector is returned. +std::vector ComputeBigO( + const std::vector& reports); + +// This data structure will contain the result returned by MinimalLeastSq +// - coef : Estimated coefficient for the high-order term as +// interpolated from data. +// - rms : Normalized Root Mean Squared Error. +// - complexity : Scalability form (e.g. oN, oNLogN). In case a scalability +// form has been provided to MinimalLeastSq this will return +// the same value. In case BigO::oAuto has been selected, this +// parameter will return the best fitting curve detected. 
+ +struct LeastSq { + LeastSq() : coef(0.0), rms(0.0), complexity(oNone) {} + + double coef; + double rms; + BigO complexity; +}; + +// Function to return an string for the calculated complexity +std::string GetBigOString(BigO complexity); + +} // end namespace benchmark + +#endif // COMPLEXITY_H_ diff --git a/third_party/benchmark/src/console_reporter.cc b/third_party/benchmark/src/console_reporter.cc new file mode 100644 index 0000000..35c3de2 --- /dev/null +++ b/third_party/benchmark/src/console_reporter.cc @@ -0,0 +1,210 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" +#include "colorprint.h" +#include "commandlineflags.h" +#include "complexity.h" +#include "counter.h" +#include "internal_macros.h" +#include "string_util.h" +#include "timers.h" + +namespace benchmark { + +BENCHMARK_EXPORT +bool ConsoleReporter::ReportContext(const Context& context) { + name_field_width_ = context.name_field_width; + printed_header_ = false; + prev_counters_.clear(); + + PrintBasicContext(&GetErrorStream(), context); + +#ifdef BENCHMARK_OS_WINDOWS + if ((output_options_ & OO_Color)) { + auto stdOutBuf = std::cout.rdbuf(); + auto outStreamBuf = GetOutputStream().rdbuf(); + if (stdOutBuf != outStreamBuf) { + GetErrorStream() + << "Color printing is only supported for stdout on windows." + " Disabling color printing\n"; + output_options_ = static_cast(output_options_ & ~OO_Color); + } + } +#endif + + return true; +} + +BENCHMARK_EXPORT +void ConsoleReporter::PrintHeader(const Run& run) { + std::string str = + FormatString("%-*s %13s %15s %12s", static_cast(name_field_width_), + "Benchmark", "Time", "CPU", "Iterations"); + if (!run.counters.empty()) { + if (output_options_ & OO_Tabular) { + for (auto const& c : run.counters) { + str += FormatString(" %10s", c.first.c_str()); + } + } else { + str += " UserCounters..."; + } + } + std::string line = std::string(str.length(), '-'); + GetOutputStream() << line << "\n" << str << "\n" << line << "\n"; +} + +BENCHMARK_EXPORT +void ConsoleReporter::ReportRuns(const std::vector& reports) { + for (const auto& run : reports) { + // print the header: + // --- if none was printed yet + bool print_header = !printed_header_; + // --- or if the format is tabular and this run + // has different fields from the prev header + print_header |= (output_options_ & OO_Tabular) && + (!internal::SameNames(run.counters, prev_counters_)); + if (print_header) { + printed_header_ = true; + 
prev_counters_ = run.counters; + PrintHeader(run); + } + // As an alternative to printing the headers like this, we could sort + // the benchmarks by header and then print. But this would require + // waiting for the full results before printing, or printing twice. + PrintRunData(run); + } +} + +static void IgnoreColorPrint(std::ostream& out, LogColor, const char* fmt, + ...) { + va_list args; + va_start(args, fmt); + out << FormatString(fmt, args); + va_end(args); +} + +static std::string FormatTime(double time) { + // For the time columns of the console printer 13 digits are reserved. One of + // them is a space and max two of them are the time unit (e.g ns). That puts + // us at 10 digits usable for the number. + // Align decimal places... + if (time < 1.0) { + return FormatString("%10.3f", time); + } + if (time < 10.0) { + return FormatString("%10.2f", time); + } + if (time < 100.0) { + return FormatString("%10.1f", time); + } + // Assuming the time is at max 9.9999e+99 and we have 10 digits for the + // number, we get 10-1(.)-1(e)-1(sign)-2(exponent) = 5 digits to print. + if (time > 9999999999 /*max 10 digit number*/) { + return FormatString("%1.4e", time); + } + return FormatString("%10.0f", time); +} + +BENCHMARK_EXPORT +void ConsoleReporter::PrintRunData(const Run& result) { + typedef void(PrinterFn)(std::ostream&, LogColor, const char*, ...); + auto& Out = GetOutputStream(); + PrinterFn* printer = (output_options_ & OO_Color) + ? static_cast(ColorPrintf) + : IgnoreColorPrint; + auto name_color = + (result.report_big_o || result.report_rms) ? 
COLOR_BLUE : COLOR_GREEN; + printer(Out, name_color, "%-*s ", name_field_width_, + result.benchmark_name().c_str()); + + if (internal::SkippedWithError == result.skipped) { + printer(Out, COLOR_RED, "ERROR OCCURRED: \'%s\'", + result.skip_message.c_str()); + printer(Out, COLOR_DEFAULT, "\n"); + return; + } else if (internal::SkippedWithMessage == result.skipped) { + printer(Out, COLOR_WHITE, "SKIPPED: \'%s\'", result.skip_message.c_str()); + printer(Out, COLOR_DEFAULT, "\n"); + return; + } + + const double real_time = result.GetAdjustedRealTime(); + const double cpu_time = result.GetAdjustedCPUTime(); + const std::string real_time_str = FormatTime(real_time); + const std::string cpu_time_str = FormatTime(cpu_time); + + if (result.report_big_o) { + std::string big_o = GetBigOString(result.complexity); + printer(Out, COLOR_YELLOW, "%10.2f %-4s %10.2f %-4s ", real_time, + big_o.c_str(), cpu_time, big_o.c_str()); + } else if (result.report_rms) { + printer(Out, COLOR_YELLOW, "%10.0f %-4s %10.0f %-4s ", real_time * 100, "%", + cpu_time * 100, "%"); + } else if (result.run_type != Run::RT_Aggregate || + result.aggregate_unit == StatisticUnit::kTime) { + const char* timeLabel = GetTimeUnitString(result.time_unit); + printer(Out, COLOR_YELLOW, "%s %-4s %s %-4s ", real_time_str.c_str(), + timeLabel, cpu_time_str.c_str(), timeLabel); + } else { + assert(result.aggregate_unit == StatisticUnit::kPercentage); + printer(Out, COLOR_YELLOW, "%10.2f %-4s %10.2f %-4s ", + (100. * result.real_accumulated_time), "%", + (100. * result.cpu_accumulated_time), "%"); + } + + if (!result.report_big_o && !result.report_rms) { + printer(Out, COLOR_CYAN, "%10lld", result.iterations); + } + + for (auto& c : result.counters) { + const std::size_t cNameLen = + std::max(std::string::size_type(10), c.first.length()); + std::string s; + const char* unit = ""; + if (result.run_type == Run::RT_Aggregate && + result.aggregate_unit == StatisticUnit::kPercentage) { + s = StrFormat("%.2f", 100. 
* c.second.value); + unit = "%"; + } else { + s = HumanReadableNumber(c.second.value, c.second.oneK); + if (c.second.flags & Counter::kIsRate) + unit = (c.second.flags & Counter::kInvert) ? "s" : "/s"; + } + if (output_options_ & OO_Tabular) { + printer(Out, COLOR_DEFAULT, " %*s%s", cNameLen - strlen(unit), s.c_str(), + unit); + } else { + printer(Out, COLOR_DEFAULT, " %s=%s%s", c.first.c_str(), s.c_str(), unit); + } + } + + if (!result.report_label.empty()) { + printer(Out, COLOR_DEFAULT, " %s", result.report_label.c_str()); + } + + printer(Out, COLOR_DEFAULT, "\n"); +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/counter.cc b/third_party/benchmark/src/counter.cc new file mode 100644 index 0000000..aa14cd8 --- /dev/null +++ b/third_party/benchmark/src/counter.cc @@ -0,0 +1,80 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "counter.h" + +namespace benchmark { +namespace internal { + +double Finish(Counter const& c, IterationCount iterations, double cpu_time, + double num_threads) { + double v = c.value; + if (c.flags & Counter::kIsRate) { + v /= cpu_time; + } + if (c.flags & Counter::kAvgThreads) { + v /= num_threads; + } + if (c.flags & Counter::kIsIterationInvariant) { + v *= static_cast(iterations); + } + if (c.flags & Counter::kAvgIterations) { + v /= static_cast(iterations); + } + + if (c.flags & Counter::kInvert) { // Invert is *always* last. 
+ v = 1.0 / v; + } + return v; +} + +void Finish(UserCounters* l, IterationCount iterations, double cpu_time, + double num_threads) { + for (auto& c : *l) { + c.second.value = Finish(c.second, iterations, cpu_time, num_threads); + } +} + +void Increment(UserCounters* l, UserCounters const& r) { + // add counters present in both or just in *l + for (auto& c : *l) { + auto it = r.find(c.first); + if (it != r.end()) { + c.second.value = c.second + it->second; + } + } + // add counters present in r, but not in *l + for (auto const& tc : r) { + auto it = l->find(tc.first); + if (it == l->end()) { + (*l)[tc.first] = tc.second; + } + } +} + +bool SameNames(UserCounters const& l, UserCounters const& r) { + if (&l == &r) return true; + if (l.size() != r.size()) { + return false; + } + for (auto const& c : l) { + if (r.find(c.first) == r.end()) { + return false; + } + } + return true; +} + +} // end namespace internal +} // end namespace benchmark diff --git a/third_party/benchmark/src/counter.h b/third_party/benchmark/src/counter.h new file mode 100644 index 0000000..1f5a58e --- /dev/null +++ b/third_party/benchmark/src/counter.h @@ -0,0 +1,32 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef BENCHMARK_COUNTER_H_ +#define BENCHMARK_COUNTER_H_ + +#include "benchmark/benchmark.h" + +namespace benchmark { + +// these counter-related functions are hidden to reduce API surface. 
+namespace internal { +void Finish(UserCounters* l, IterationCount iterations, double time, + double num_threads); +void Increment(UserCounters* l, UserCounters const& r); +bool SameNames(UserCounters const& l, UserCounters const& r); +} // end namespace internal + +} // end namespace benchmark + +#endif // BENCHMARK_COUNTER_H_ diff --git a/third_party/benchmark/src/csv_reporter.cc b/third_party/benchmark/src/csv_reporter.cc new file mode 100644 index 0000000..4b39e2c --- /dev/null +++ b/third_party/benchmark/src/csv_reporter.cc @@ -0,0 +1,169 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" +#include "complexity.h" +#include "string_util.h" +#include "timers.h" + +// File format reference: http://edoceo.com/utilitas/csv-file-format. 
+ +namespace benchmark { + +namespace { +std::vector elements = { + "name", "iterations", "real_time", "cpu_time", + "time_unit", "bytes_per_second", "items_per_second", "label", + "error_occurred", "error_message"}; +} // namespace + +std::string CsvEscape(const std::string& s) { + std::string tmp; + tmp.reserve(s.size() + 2); + for (char c : s) { + switch (c) { + case '"': + tmp += "\"\""; + break; + default: + tmp += c; + break; + } + } + return '"' + tmp + '"'; +} + +BENCHMARK_EXPORT +bool CSVReporter::ReportContext(const Context& context) { + PrintBasicContext(&GetErrorStream(), context); + return true; +} + +BENCHMARK_EXPORT +void CSVReporter::ReportRuns(const std::vector& reports) { + std::ostream& Out = GetOutputStream(); + + if (!printed_header_) { + // save the names of all the user counters + for (const auto& run : reports) { + for (const auto& cnt : run.counters) { + if (cnt.first == "bytes_per_second" || cnt.first == "items_per_second") + continue; + user_counter_names_.insert(cnt.first); + } + } + + // print the header + for (auto B = elements.begin(); B != elements.end();) { + Out << *B++; + if (B != elements.end()) Out << ","; + } + for (auto B = user_counter_names_.begin(); + B != user_counter_names_.end();) { + Out << ",\"" << *B++ << "\""; + } + Out << "\n"; + + printed_header_ = true; + } else { + // check that all the current counters are saved in the name set + for (const auto& run : reports) { + for (const auto& cnt : run.counters) { + if (cnt.first == "bytes_per_second" || cnt.first == "items_per_second") + continue; + BM_CHECK(user_counter_names_.find(cnt.first) != + user_counter_names_.end()) + << "All counters must be present in each run. 
" + << "Counter named \"" << cnt.first + << "\" was not in a run after being added to the header"; + } + } + } + + // print results for each run + for (const auto& run : reports) { + PrintRunData(run); + } +} + +BENCHMARK_EXPORT +void CSVReporter::PrintRunData(const Run& run) { + std::ostream& Out = GetOutputStream(); + Out << CsvEscape(run.benchmark_name()) << ","; + if (run.skipped) { + Out << std::string(elements.size() - 3, ','); + Out << std::boolalpha << (internal::SkippedWithError == run.skipped) << ","; + Out << CsvEscape(run.skip_message) << "\n"; + return; + } + + // Do not print iteration on bigO and RMS report + if (!run.report_big_o && !run.report_rms) { + Out << run.iterations; + } + Out << ","; + + if (run.run_type != Run::RT_Aggregate || + run.aggregate_unit == StatisticUnit::kTime) { + Out << run.GetAdjustedRealTime() << ","; + Out << run.GetAdjustedCPUTime() << ","; + } else { + assert(run.aggregate_unit == StatisticUnit::kPercentage); + Out << run.real_accumulated_time << ","; + Out << run.cpu_accumulated_time << ","; + } + + // Do not print timeLabel on bigO and RMS report + if (run.report_big_o) { + Out << GetBigOString(run.complexity); + } else if (!run.report_rms && + run.aggregate_unit != StatisticUnit::kPercentage) { + Out << GetTimeUnitString(run.time_unit); + } + Out << ","; + + if (run.counters.find("bytes_per_second") != run.counters.end()) { + Out << run.counters.at("bytes_per_second"); + } + Out << ","; + if (run.counters.find("items_per_second") != run.counters.end()) { + Out << run.counters.at("items_per_second"); + } + Out << ","; + if (!run.report_label.empty()) { + Out << CsvEscape(run.report_label); + } + Out << ",,"; // for error_occurred and error_message + + // Print user counters + for (const auto& ucn : user_counter_names_) { + auto it = run.counters.find(ucn); + if (it == run.counters.end()) { + Out << ","; + } else { + Out << "," << it->second; + } + } + Out << '\n'; +} + +} // end namespace benchmark diff --git 
a/third_party/benchmark/src/cycleclock.h b/third_party/benchmark/src/cycleclock.h new file mode 100644 index 0000000..bd62f5d --- /dev/null +++ b/third_party/benchmark/src/cycleclock.h @@ -0,0 +1,243 @@ +// ---------------------------------------------------------------------- +// CycleClock +// A CycleClock tells you the current time in Cycles. The "time" +// is actually time since power-on. This is like time() but doesn't +// involve a system call and is much more precise. +// +// NOTE: Not all cpu/platform/kernel combinations guarantee that this +// clock increments at a constant rate or is synchronized across all logical +// cpus in a system. +// +// If you need the above guarantees, please consider using a different +// API. There are efforts to provide an interface which provides a millisecond +// granularity and implemented as a memory read. A memory read is generally +// cheaper than the CycleClock for many architectures. +// +// Also, in some out of order CPU implementations, the CycleClock is not +// serializing. So if you're trying to count at cycles granularity, your +// data might be inaccurate due to out of order instruction execution. +// ---------------------------------------------------------------------- + +#ifndef BENCHMARK_CYCLECLOCK_H_ +#define BENCHMARK_CYCLECLOCK_H_ + +#include + +#include "benchmark/benchmark.h" +#include "internal_macros.h" + +#if defined(BENCHMARK_OS_MACOSX) +#include +#endif +// For MSVC, we want to use '_asm rdtsc' when possible (since it works +// with even ancient MSVC compilers), and when not possible the +// __rdtsc intrinsic, declared in . Unfortunately, in some +// environments, and have conflicting +// declarations of some other intrinsics, breaking compilation. +// Therefore, we simply declare __rdtsc ourselves. 
See also +// http://connect.microsoft.com/VisualStudio/feedback/details/262047 +#if defined(COMPILER_MSVC) && !defined(_M_IX86) && !defined(_M_ARM64) && \ + !defined(_M_ARM64EC) +extern "C" uint64_t __rdtsc(); +#pragma intrinsic(__rdtsc) +#endif + +#if !defined(BENCHMARK_OS_WINDOWS) || defined(BENCHMARK_OS_MINGW) +#include +#include +#endif + +#ifdef BENCHMARK_OS_EMSCRIPTEN +#include +#endif + +namespace benchmark { +// NOTE: only i386 and x86_64 have been well tested. +// PPC, sparc, alpha, and ia64 are based on +// http://peter.kuscsik.com/wordpress/?p=14 +// with modifications by m3b. See also +// https://setisvn.ssl.berkeley.edu/svn/lib/fftw-3.0.1/kernel/cycle.h +namespace cycleclock { +// This should return the number of cycles since power-on. Thread-safe. +inline BENCHMARK_ALWAYS_INLINE int64_t Now() { +#if defined(BENCHMARK_OS_MACOSX) + // this goes at the top because we need ALL Macs, regardless of + // architecture, to return the number of "mach time units" that + // have passed since startup. See sysinfo.cc where + // InitializeSystemInfo() sets the supposed cpu clock frequency of + // macs to the number of mach time units per second, not actual + // CPU clock frequency (which can change in the face of CPU + // frequency scaling). Also note that when the Mac sleeps, this + // counter pauses; it does not continue counting, nor does it + // reset to zero. + return static_cast(mach_absolute_time()); +#elif defined(BENCHMARK_OS_EMSCRIPTEN) + // this goes above x86-specific code because old versions of Emscripten + // define __x86_64__, although they have nothing to do with it. 
+ return static_cast(emscripten_get_now() * 1e+6); +#elif defined(__i386__) + int64_t ret; + __asm__ volatile("rdtsc" : "=A"(ret)); + return ret; +#elif defined(__x86_64__) || defined(__amd64__) + uint64_t low, high; + __asm__ volatile("rdtsc" : "=a"(low), "=d"(high)); + return static_cast((high << 32) | low); +#elif defined(__powerpc__) || defined(__ppc__) + // This returns a time-base, which is not always precisely a cycle-count. +#if defined(__powerpc64__) || defined(__ppc64__) + int64_t tb; + asm volatile("mfspr %0, 268" : "=r"(tb)); + return tb; +#else + uint32_t tbl, tbu0, tbu1; + asm volatile( + "mftbu %0\n" + "mftb %1\n" + "mftbu %2" + : "=r"(tbu0), "=r"(tbl), "=r"(tbu1)); + tbl &= -static_cast(tbu0 == tbu1); + // high 32 bits in tbu1; low 32 bits in tbl (tbu0 is no longer needed) + return (static_cast(tbu1) << 32) | tbl; +#endif +#elif defined(__sparc__) + int64_t tick; + asm(".byte 0x83, 0x41, 0x00, 0x00"); + asm("mov %%g1, %0" : "=r"(tick)); + return tick; +#elif defined(__ia64__) + int64_t itc; + asm("mov %0 = ar.itc" : "=r"(itc)); + return itc; +#elif defined(COMPILER_MSVC) && defined(_M_IX86) + // Older MSVC compilers (like 7.x) don't seem to support the + // __rdtsc intrinsic properly, so I prefer to use _asm instead + // when I know it will work. Otherwise, I'll use __rdtsc and hope + // the code is being compiled with a non-ancient compiler. + _asm rdtsc +#elif defined(COMPILER_MSVC) && (defined(_M_ARM64) || defined(_M_ARM64EC)) + // See // https://docs.microsoft.com/en-us/cpp/intrinsics/arm64-intrinsics + // and https://reviews.llvm.org/D53115 + int64_t virtual_timer_value; + virtual_timer_value = _ReadStatusReg(ARM64_CNTVCT); + return virtual_timer_value; +#elif defined(COMPILER_MSVC) + return __rdtsc(); +#elif defined(BENCHMARK_OS_NACL) + // Native Client validator on x86/x86-64 allows RDTSC instructions, + // and this case is handled above. 
Native Client validator on ARM + // rejects MRC instructions (used in the ARM-specific sequence below), + // so we handle it here. Portable Native Client compiles to + // architecture-agnostic bytecode, which doesn't provide any + // cycle counter access mnemonics. + + // Native Client does not provide any API to access cycle counter. + // Use clock_gettime(CLOCK_MONOTONIC, ...) instead of gettimeofday + // because is provides nanosecond resolution (which is noticeable at + // least for PNaCl modules running on x86 Mac & Linux). + // Initialize to always return 0 if clock_gettime fails. + struct timespec ts = {0, 0}; + clock_gettime(CLOCK_MONOTONIC, &ts); + return static_cast(ts.tv_sec) * 1000000000 + ts.tv_nsec; +#elif defined(__aarch64__) + // System timer of ARMv8 runs at a different frequency than the CPU's. + // The frequency is fixed, typically in the range 1-50MHz. It can be + // read at CNTFRQ special register. We assume the OS has set up + // the virtual timer properly. + int64_t virtual_timer_value; + asm volatile("mrs %0, cntvct_el0" : "=r"(virtual_timer_value)); + return virtual_timer_value; +#elif defined(__ARM_ARCH) + // V6 is the earliest arch that has a standard cyclecount + // Native Client validator doesn't allow MRC instructions. +#if (__ARM_ARCH >= 6) + uint32_t pmccntr; + uint32_t pmuseren; + uint32_t pmcntenset; + // Read the user mode perf monitor counter access permissions. + asm volatile("mrc p15, 0, %0, c9, c14, 0" : "=r"(pmuseren)); + if (pmuseren & 1) { // Allows reading perfmon counters for user mode code. + asm volatile("mrc p15, 0, %0, c9, c12, 1" : "=r"(pmcntenset)); + if (pmcntenset & 0x80000000ul) { // Is it counting? 
+ asm volatile("mrc p15, 0, %0, c9, c13, 0" : "=r"(pmccntr)); + // The counter is set up to count every 64th cycle + return static_cast(pmccntr) * 64; // Should optimize to << 6 + } + } +#endif + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +#elif defined(__mips__) || defined(__m68k__) + // mips apparently only allows rdtsc for superusers, so we fall + // back to gettimeofday. It's possible clock_gettime would be better. + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +#elif defined(__loongarch__) || defined(__csky__) + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +#elif defined(__s390__) // Covers both s390 and s390x. + // Return the CPU clock. + uint64_t tsc; +#if defined(BENCHMARK_OS_ZOS) + // z/OS HLASM syntax. + asm(" stck %0" : "=m"(tsc) : : "cc"); +#else + // Linux on Z syntax. + asm("stck %0" : "=Q"(tsc) : : "cc"); +#endif + return tsc; +#elif defined(__riscv) // RISC-V + // Use RDTIME (and RDTIMEH on riscv32). + // RDCYCLE is a privileged instruction since Linux 6.6. +#if __riscv_xlen == 32 + uint32_t cycles_lo, cycles_hi0, cycles_hi1; + // This asm also includes the PowerPC overflow handling strategy, as above. + // Implemented in assembly because Clang insisted on branching. 
+ asm volatile( + "rdtimeh %0\n" + "rdtime %1\n" + "rdtimeh %2\n" + "sub %0, %0, %2\n" + "seqz %0, %0\n" + "sub %0, zero, %0\n" + "and %1, %1, %0\n" + : "=r"(cycles_hi0), "=r"(cycles_lo), "=r"(cycles_hi1)); + return static_cast((static_cast(cycles_hi1) << 32) | + cycles_lo); +#else + uint64_t cycles; + asm volatile("rdtime %0" : "=r"(cycles)); + return static_cast(cycles); +#endif +#elif defined(__e2k__) || defined(__elbrus__) + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +#elif defined(__hexagon__) + uint64_t pcycle; + asm volatile("%0 = C15:14" : "=r"(pcycle)); + return static_cast(pcycle); +#elif defined(__alpha__) + // Alpha has a cycle counter, the PCC register, but it is an unsigned 32-bit + // integer and thus wraps every ~4s, making using it for tick counts + // unreliable beyond this time range. The real-time clock is low-precision, + // roughtly ~1ms, but it is the only option that can reasonable count + // indefinitely. + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +#else + // The soft failover to a generic implementation is automatic only for ARM. + // For other platforms the developer is expected to make an attempt to create + // a fast implementation and use generic version if nothing better is + // available. 
+#error You need to define CycleTimer for your OS and CPU +#endif +} +} // end namespace cycleclock +} // end namespace benchmark + +#endif // BENCHMARK_CYCLECLOCK_H_ diff --git a/third_party/benchmark/src/internal_macros.h b/third_party/benchmark/src/internal_macros.h new file mode 100644 index 0000000..f4894ba --- /dev/null +++ b/third_party/benchmark/src/internal_macros.h @@ -0,0 +1,111 @@ +#ifndef BENCHMARK_INTERNAL_MACROS_H_ +#define BENCHMARK_INTERNAL_MACROS_H_ + +/* Needed to detect STL */ +#include + +// clang-format off + +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if defined(__clang__) + #if !defined(COMPILER_CLANG) + #define COMPILER_CLANG + #endif +#elif defined(_MSC_VER) + #if !defined(COMPILER_MSVC) + #define COMPILER_MSVC + #endif +#elif defined(__GNUC__) + #if !defined(COMPILER_GCC) + #define COMPILER_GCC + #endif +#endif + +#if __has_feature(cxx_attributes) + #define BENCHMARK_NORETURN [[noreturn]] +#elif defined(__GNUC__) + #define BENCHMARK_NORETURN __attribute__((noreturn)) +#elif defined(COMPILER_MSVC) + #define BENCHMARK_NORETURN __declspec(noreturn) +#else + #define BENCHMARK_NORETURN +#endif + +#if defined(__CYGWIN__) + #define BENCHMARK_OS_CYGWIN 1 +#elif defined(_WIN32) + #define BENCHMARK_OS_WINDOWS 1 + // WINAPI_FAMILY_PARTITION is defined in winapifamily.h. + // We include windows.h which implicitly includes winapifamily.h for compatibility. 
+ #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #if defined(WINAPI_FAMILY_PARTITION) + #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + #define BENCHMARK_OS_WINDOWS_WIN32 1 + #elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) + #define BENCHMARK_OS_WINDOWS_RT 1 + #endif + #endif + #if defined(__MINGW32__) + #define BENCHMARK_OS_MINGW 1 + #endif +#elif defined(__APPLE__) + #define BENCHMARK_OS_APPLE 1 + #include "TargetConditionals.h" + #if defined(TARGET_OS_MAC) + #define BENCHMARK_OS_MACOSX 1 + #if defined(TARGET_OS_IPHONE) + #define BENCHMARK_OS_IOS 1 + #endif + #endif +#elif defined(__FreeBSD__) + #define BENCHMARK_OS_FREEBSD 1 +#elif defined(__NetBSD__) + #define BENCHMARK_OS_NETBSD 1 +#elif defined(__OpenBSD__) + #define BENCHMARK_OS_OPENBSD 1 +#elif defined(__DragonFly__) + #define BENCHMARK_OS_DRAGONFLY 1 +#elif defined(__linux__) + #define BENCHMARK_OS_LINUX 1 +#elif defined(__native_client__) + #define BENCHMARK_OS_NACL 1 +#elif defined(__EMSCRIPTEN__) + #define BENCHMARK_OS_EMSCRIPTEN 1 +#elif defined(__rtems__) + #define BENCHMARK_OS_RTEMS 1 +#elif defined(__Fuchsia__) +#define BENCHMARK_OS_FUCHSIA 1 +#elif defined (__SVR4) && defined (__sun) +#define BENCHMARK_OS_SOLARIS 1 +#elif defined(__QNX__) +#define BENCHMARK_OS_QNX 1 +#elif defined(__MVS__) +#define BENCHMARK_OS_ZOS 1 +#elif defined(__hexagon__) +#define BENCHMARK_OS_QURT 1 +#endif + +#if defined(__ANDROID__) && defined(__GLIBCXX__) +#define BENCHMARK_STL_ANDROID_GNUSTL 1 +#endif + +#if !__has_feature(cxx_exceptions) && !defined(__cpp_exceptions) \ + && !defined(__EXCEPTIONS) + #define BENCHMARK_HAS_NO_EXCEPTIONS +#endif + +#if defined(COMPILER_CLANG) || defined(COMPILER_GCC) + #define BENCHMARK_MAYBE_UNUSED __attribute__((unused)) +#else + #define BENCHMARK_MAYBE_UNUSED +#endif + +// clang-format on + +#endif // BENCHMARK_INTERNAL_MACROS_H_ diff --git a/third_party/benchmark/src/json_reporter.cc b/third_party/benchmark/src/json_reporter.cc new file mode 100644 index 
0000000..b8c8c94 --- /dev/null +++ b/third_party/benchmark/src/json_reporter.cc @@ -0,0 +1,327 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include // for setprecision +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "complexity.h" +#include "string_util.h" +#include "timers.h" + +namespace benchmark { +namespace { + +std::string StrEscape(const std::string& s) { + std::string tmp; + tmp.reserve(s.size()); + for (char c : s) { + switch (c) { + case '\b': + tmp += "\\b"; + break; + case '\f': + tmp += "\\f"; + break; + case '\n': + tmp += "\\n"; + break; + case '\r': + tmp += "\\r"; + break; + case '\t': + tmp += "\\t"; + break; + case '\\': + tmp += "\\\\"; + break; + case '"': + tmp += "\\\""; + break; + default: + tmp += c; + break; + } + } + return tmp; +} + +std::string FormatKV(std::string const& key, std::string const& value) { + return StrFormat("\"%s\": \"%s\"", StrEscape(key).c_str(), + StrEscape(value).c_str()); +} + +std::string FormatKV(std::string const& key, const char* value) { + return StrFormat("\"%s\": \"%s\"", StrEscape(key).c_str(), + StrEscape(value).c_str()); +} + +std::string FormatKV(std::string const& key, bool value) { + return StrFormat("\"%s\": %s", StrEscape(key).c_str(), + value ? 
"true" : "false"); +} + +std::string FormatKV(std::string const& key, int64_t value) { + std::stringstream ss; + ss << '"' << StrEscape(key) << "\": " << value; + return ss.str(); +} + +std::string FormatKV(std::string const& key, double value) { + std::stringstream ss; + ss << '"' << StrEscape(key) << "\": "; + + if (std::isnan(value)) + ss << (value < 0 ? "-" : "") << "NaN"; + else if (std::isinf(value)) + ss << (value < 0 ? "-" : "") << "Infinity"; + else { + const auto max_digits10 = + std::numeric_limits::max_digits10; + const auto max_fractional_digits10 = max_digits10 - 1; + ss << std::scientific << std::setprecision(max_fractional_digits10) + << value; + } + return ss.str(); +} + +int64_t RoundDouble(double v) { return std::lround(v); } + +} // end namespace + +bool JSONReporter::ReportContext(const Context& context) { + std::ostream& out = GetOutputStream(); + + out << "{\n"; + std::string inner_indent(2, ' '); + + // Open context block and print context information. + out << inner_indent << "\"context\": {\n"; + std::string indent(4, ' '); + + std::string walltime_value = LocalDateTimeString(); + out << indent << FormatKV("date", walltime_value) << ",\n"; + + out << indent << FormatKV("host_name", context.sys_info.name) << ",\n"; + + if (Context::executable_name) { + out << indent << FormatKV("executable", Context::executable_name) << ",\n"; + } + + CPUInfo const& info = context.cpu_info; + out << indent << FormatKV("num_cpus", static_cast(info.num_cpus)) + << ",\n"; + out << indent + << FormatKV("mhz_per_cpu", + RoundDouble(info.cycles_per_second / 1000000.0)) + << ",\n"; + if (CPUInfo::Scaling::UNKNOWN != info.scaling) { + out << indent + << FormatKV("cpu_scaling_enabled", + info.scaling == CPUInfo::Scaling::ENABLED ? 
true : false) + << ",\n"; + } + + out << indent << "\"caches\": [\n"; + indent = std::string(6, ' '); + std::string cache_indent(8, ' '); + for (size_t i = 0; i < info.caches.size(); ++i) { + auto& CI = info.caches[i]; + out << indent << "{\n"; + out << cache_indent << FormatKV("type", CI.type) << ",\n"; + out << cache_indent << FormatKV("level", static_cast(CI.level)) + << ",\n"; + out << cache_indent << FormatKV("size", static_cast(CI.size)) + << ",\n"; + out << cache_indent + << FormatKV("num_sharing", static_cast(CI.num_sharing)) + << "\n"; + out << indent << "}"; + if (i != info.caches.size() - 1) out << ","; + out << "\n"; + } + indent = std::string(4, ' '); + out << indent << "],\n"; + out << indent << "\"load_avg\": ["; + for (auto it = info.load_avg.begin(); it != info.load_avg.end();) { + out << *it++; + if (it != info.load_avg.end()) out << ","; + } + out << "],\n"; + + out << indent << FormatKV("library_version", GetBenchmarkVersion()); + out << ",\n"; + +#if defined(NDEBUG) + const char build_type[] = "release"; +#else + const char build_type[] = "debug"; +#endif + out << indent << FormatKV("library_build_type", build_type); + out << ",\n"; + + // NOTE: our json schema is not strictly tied to the library version! + out << indent << FormatKV("json_schema_version", int64_t(1)); + + std::map* global_context = + internal::GetGlobalContext(); + + if (global_context != nullptr) { + for (const auto& kv : *global_context) { + out << ",\n"; + out << indent << FormatKV(kv.first, kv.second); + } + } + out << "\n"; + + // Close context block and open the list of benchmarks. 
+ out << inner_indent << "},\n"; + out << inner_indent << "\"benchmarks\": [\n"; + return true; +} + +void JSONReporter::ReportRuns(std::vector const& reports) { + if (reports.empty()) { + return; + } + std::string indent(4, ' '); + std::ostream& out = GetOutputStream(); + if (!first_report_) { + out << ",\n"; + } + first_report_ = false; + + for (auto it = reports.begin(); it != reports.end(); ++it) { + out << indent << "{\n"; + PrintRunData(*it); + out << indent << '}'; + auto it_cp = it; + if (++it_cp != reports.end()) { + out << ",\n"; + } + } +} + +void JSONReporter::Finalize() { + // Close the list of benchmarks and the top level object. + GetOutputStream() << "\n ]\n}\n"; +} + +void JSONReporter::PrintRunData(Run const& run) { + std::string indent(6, ' '); + std::ostream& out = GetOutputStream(); + out << indent << FormatKV("name", run.benchmark_name()) << ",\n"; + out << indent << FormatKV("family_index", run.family_index) << ",\n"; + out << indent + << FormatKV("per_family_instance_index", run.per_family_instance_index) + << ",\n"; + out << indent << FormatKV("run_name", run.run_name.str()) << ",\n"; + out << indent << FormatKV("run_type", [&run]() -> const char* { + switch (run.run_type) { + case BenchmarkReporter::Run::RT_Iteration: + return "iteration"; + case BenchmarkReporter::Run::RT_Aggregate: + return "aggregate"; + } + BENCHMARK_UNREACHABLE(); + }()) << ",\n"; + out << indent << FormatKV("repetitions", run.repetitions) << ",\n"; + if (run.run_type != BenchmarkReporter::Run::RT_Aggregate) { + out << indent << FormatKV("repetition_index", run.repetition_index) + << ",\n"; + } + out << indent << FormatKV("threads", run.threads) << ",\n"; + if (run.run_type == BenchmarkReporter::Run::RT_Aggregate) { + out << indent << FormatKV("aggregate_name", run.aggregate_name) << ",\n"; + out << indent << FormatKV("aggregate_unit", [&run]() -> const char* { + switch (run.aggregate_unit) { + case StatisticUnit::kTime: + return "time"; + case 
StatisticUnit::kPercentage: + return "percentage"; + } + BENCHMARK_UNREACHABLE(); + }()) << ",\n"; + } + if (internal::SkippedWithError == run.skipped) { + out << indent << FormatKV("error_occurred", true) << ",\n"; + out << indent << FormatKV("error_message", run.skip_message) << ",\n"; + } else if (internal::SkippedWithMessage == run.skipped) { + out << indent << FormatKV("skipped", true) << ",\n"; + out << indent << FormatKV("skip_message", run.skip_message) << ",\n"; + } + if (!run.report_big_o && !run.report_rms) { + out << indent << FormatKV("iterations", run.iterations) << ",\n"; + if (run.run_type != Run::RT_Aggregate || + run.aggregate_unit == StatisticUnit::kTime) { + out << indent << FormatKV("real_time", run.GetAdjustedRealTime()) + << ",\n"; + out << indent << FormatKV("cpu_time", run.GetAdjustedCPUTime()); + } else { + assert(run.aggregate_unit == StatisticUnit::kPercentage); + out << indent << FormatKV("real_time", run.real_accumulated_time) + << ",\n"; + out << indent << FormatKV("cpu_time", run.cpu_accumulated_time); + } + out << ",\n" + << indent << FormatKV("time_unit", GetTimeUnitString(run.time_unit)); + } else if (run.report_big_o) { + out << indent << FormatKV("cpu_coefficient", run.GetAdjustedCPUTime()) + << ",\n"; + out << indent << FormatKV("real_coefficient", run.GetAdjustedRealTime()) + << ",\n"; + out << indent << FormatKV("big_o", GetBigOString(run.complexity)) << ",\n"; + out << indent << FormatKV("time_unit", GetTimeUnitString(run.time_unit)); + } else if (run.report_rms) { + out << indent << FormatKV("rms", run.GetAdjustedCPUTime()); + } + + for (auto& c : run.counters) { + out << ",\n" << indent << FormatKV(c.first, c.second); + } + + if (run.memory_result) { + const MemoryManager::Result memory_result = *run.memory_result; + out << ",\n" << indent << FormatKV("allocs_per_iter", run.allocs_per_iter); + out << ",\n" + << indent << FormatKV("max_bytes_used", memory_result.max_bytes_used); + + auto report_if_present = [&out, 
&indent](const std::string& label, + int64_t val) { + if (val != MemoryManager::TombstoneValue) + out << ",\n" << indent << FormatKV(label, val); + }; + + report_if_present("total_allocated_bytes", + memory_result.total_allocated_bytes); + report_if_present("net_heap_growth", memory_result.net_heap_growth); + } + + if (!run.report_label.empty()) { + out << ",\n" << indent << FormatKV("label", run.report_label); + } + out << '\n'; +} + +const int64_t MemoryManager::TombstoneValue = + std::numeric_limits::max(); + +} // end namespace benchmark diff --git a/third_party/benchmark/src/log.h b/third_party/benchmark/src/log.h new file mode 100644 index 0000000..9a21400 --- /dev/null +++ b/third_party/benchmark/src/log.h @@ -0,0 +1,88 @@ +#ifndef BENCHMARK_LOG_H_ +#define BENCHMARK_LOG_H_ + +#include +#include + +// NOTE: this is also defined in benchmark.h but we're trying to avoid a +// dependency. +// The _MSVC_LANG check should detect Visual Studio 2015 Update 3 and newer. +#if __cplusplus >= 201103L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201103L) +#define BENCHMARK_HAS_CXX11 +#endif + +namespace benchmark { +namespace internal { + +typedef std::basic_ostream&(EndLType)(std::basic_ostream&); + +class LogType { + friend LogType& GetNullLogInstance(); + friend LogType& GetErrorLogInstance(); + + // FIXME: Add locking to output. + template + friend LogType& operator<<(LogType&, Tp const&); + friend LogType& operator<<(LogType&, EndLType*); + + private: + LogType(std::ostream* out) : out_(out) {} + std::ostream* out_; + + // NOTE: we could use BENCHMARK_DISALLOW_COPY_AND_ASSIGN but we shouldn't have + // a dependency on benchmark.h from here. 
+#ifndef BENCHMARK_HAS_CXX11 + LogType(const LogType&); + LogType& operator=(const LogType&); +#else + LogType(const LogType&) = delete; + LogType& operator=(const LogType&) = delete; +#endif +}; + +template +LogType& operator<<(LogType& log, Tp const& value) { + if (log.out_) { + *log.out_ << value; + } + return log; +} + +inline LogType& operator<<(LogType& log, EndLType* m) { + if (log.out_) { + *log.out_ << m; + } + return log; +} + +inline int& LogLevel() { + static int log_level = 0; + return log_level; +} + +inline LogType& GetNullLogInstance() { + static LogType null_log(static_cast(nullptr)); + return null_log; +} + +inline LogType& GetErrorLogInstance() { + static LogType error_log(&std::clog); + return error_log; +} + +inline LogType& GetLogInstanceForLevel(int level) { + if (level <= LogLevel()) { + return GetErrorLogInstance(); + } + return GetNullLogInstance(); +} + +} // end namespace internal +} // end namespace benchmark + +// clang-format off +#define BM_VLOG(x) \ + (::benchmark::internal::GetLogInstanceForLevel(x) << "-- LOG(" << x << "):" \ + " ") +// clang-format on +#endif diff --git a/third_party/benchmark/src/mutex.h b/third_party/benchmark/src/mutex.h new file mode 100644 index 0000000..bec78d9 --- /dev/null +++ b/third_party/benchmark/src/mutex.h @@ -0,0 +1,155 @@ +#ifndef BENCHMARK_MUTEX_H_ +#define BENCHMARK_MUTEX_H_ + +#include +#include + +#include "check.h" + +// Enable thread safety attributes only with clang. +// The attributes can be safely erased when compiling with other compilers. 
+#if defined(HAVE_THREAD_SAFETY_ATTRIBUTES) +#define THREAD_ANNOTATION_ATTRIBUTE_(x) __attribute__((x)) +#else +#define THREAD_ANNOTATION_ATTRIBUTE_(x) // no-op +#endif + +#define CAPABILITY(x) THREAD_ANNOTATION_ATTRIBUTE_(capability(x)) + +#define SCOPED_CAPABILITY THREAD_ANNOTATION_ATTRIBUTE_(scoped_lockable) + +#define GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE_(guarded_by(x)) + +#define PT_GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE_(pt_guarded_by(x)) + +#define ACQUIRED_BEFORE(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(acquired_before(__VA_ARGS__)) + +#define ACQUIRED_AFTER(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(acquired_after(__VA_ARGS__)) + +#define REQUIRES(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(requires_capability(__VA_ARGS__)) + +#define REQUIRES_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(requires_shared_capability(__VA_ARGS__)) + +#define ACQUIRE(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(acquire_capability(__VA_ARGS__)) + +#define ACQUIRE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(acquire_shared_capability(__VA_ARGS__)) + +#define RELEASE(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(release_capability(__VA_ARGS__)) + +#define RELEASE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(release_shared_capability(__VA_ARGS__)) + +#define TRY_ACQUIRE(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(try_acquire_capability(__VA_ARGS__)) + +#define TRY_ACQUIRE_SHARED(...) \ + THREAD_ANNOTATION_ATTRIBUTE_(try_acquire_shared_capability(__VA_ARGS__)) + +#define EXCLUDES(...) 
THREAD_ANNOTATION_ATTRIBUTE_(locks_excluded(__VA_ARGS__)) + +#define ASSERT_CAPABILITY(x) THREAD_ANNOTATION_ATTRIBUTE_(assert_capability(x)) + +#define ASSERT_SHARED_CAPABILITY(x) \ + THREAD_ANNOTATION_ATTRIBUTE_(assert_shared_capability(x)) + +#define RETURN_CAPABILITY(x) THREAD_ANNOTATION_ATTRIBUTE_(lock_returned(x)) + +#define NO_THREAD_SAFETY_ANALYSIS \ + THREAD_ANNOTATION_ATTRIBUTE_(no_thread_safety_analysis) + +namespace benchmark { + +typedef std::condition_variable Condition; + +// NOTE: Wrappers for std::mutex and std::unique_lock are provided so that +// we can annotate them with thread safety attributes and use the +// -Wthread-safety warning with clang. The standard library types cannot be +// used directly because they do not provide the required annotations. +class CAPABILITY("mutex") Mutex { + public: + Mutex() {} + + void lock() ACQUIRE() { mut_.lock(); } + void unlock() RELEASE() { mut_.unlock(); } + std::mutex& native_handle() { return mut_; } + + private: + std::mutex mut_; +}; + +class SCOPED_CAPABILITY MutexLock { + typedef std::unique_lock MutexLockImp; + + public: + MutexLock(Mutex& m) ACQUIRE(m) : ml_(m.native_handle()) {} + ~MutexLock() RELEASE() {} + MutexLockImp& native_handle() { return ml_; } + + private: + MutexLockImp ml_; +}; + +class Barrier { + public: + Barrier(int num_threads) : running_threads_(num_threads) {} + + // Called by each thread + bool wait() EXCLUDES(lock_) { + bool last_thread = false; + { + MutexLock ml(lock_); + last_thread = createBarrier(ml); + } + if (last_thread) phase_condition_.notify_all(); + return last_thread; + } + + void removeThread() EXCLUDES(lock_) { + MutexLock ml(lock_); + --running_threads_; + if (entered_ != 0) phase_condition_.notify_all(); + } + + private: + Mutex lock_; + Condition phase_condition_; + int running_threads_; + + // State for barrier management + int phase_number_ = 0; + int entered_ = 0; // Number of threads that have entered this barrier + + // Enter the barrier and wait until 
all other threads have also + // entered the barrier. Returns iff this is the last thread to + // enter the barrier. + bool createBarrier(MutexLock& ml) REQUIRES(lock_) { + BM_CHECK_LT(entered_, running_threads_); + entered_++; + if (entered_ < running_threads_) { + // Wait for all threads to enter + int phase_number_cp = phase_number_; + auto cb = [this, phase_number_cp]() { + return this->phase_number_ > phase_number_cp || + entered_ == running_threads_; // A thread has aborted in error + }; + phase_condition_.wait(ml.native_handle(), cb); + if (phase_number_ > phase_number_cp) return false; + // else (running_threads_ == entered_) and we are the last thread. + } + // Last thread has reached the barrier + phase_number_++; + entered_ = 0; + return true; + } +}; + +} // end namespace benchmark + +#endif // BENCHMARK_MUTEX_H_ diff --git a/third_party/benchmark/src/perf_counters.cc b/third_party/benchmark/src/perf_counters.cc new file mode 100644 index 0000000..a2fa7fe --- /dev/null +++ b/third_party/benchmark/src/perf_counters.cc @@ -0,0 +1,282 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "perf_counters.h" + +#include +#include +#include + +#if defined HAVE_LIBPFM +#include "perfmon/pfmlib.h" +#include "perfmon/pfmlib_perf_event.h" +#endif + +namespace benchmark { +namespace internal { + +#if defined HAVE_LIBPFM + +size_t PerfCounterValues::Read(const std::vector& leaders) { + // Create a pointer for multiple reads + const size_t bufsize = values_.size() * sizeof(values_[0]); + char* ptr = reinterpret_cast(values_.data()); + size_t size = bufsize; + for (int lead : leaders) { + auto read_bytes = ::read(lead, ptr, size); + if (read_bytes >= ssize_t(sizeof(uint64_t))) { + // Actual data bytes are all bytes minus initial padding + std::size_t data_bytes = + static_cast(read_bytes) - sizeof(uint64_t); + // This should be very cheap since it's in hot cache + std::memmove(ptr, ptr + sizeof(uint64_t), data_bytes); + // Increment our counters + ptr += data_bytes; + size -= data_bytes; + } else { + int err = errno; + GetErrorLogInstance() << "Error reading lead " << lead << " errno:" << err + << " " << ::strerror(err) << "\n"; + return 0; + } + } + return (bufsize - size) / sizeof(uint64_t); +} + +const bool PerfCounters::kSupported = true; + +// Initializes libpfm only on the first call. Returns whether that single +// initialization was successful. +bool PerfCounters::Initialize() { + // Function-scope static gets initialized only once on first call. 
+ static const bool success = []() { + return pfm_initialize() == PFM_SUCCESS; + }(); + return success; +} + +bool PerfCounters::IsCounterSupported(const std::string& name) { + Initialize(); + perf_event_attr_t attr; + std::memset(&attr, 0, sizeof(attr)); + pfm_perf_encode_arg_t arg; + std::memset(&arg, 0, sizeof(arg)); + arg.attr = &attr; + const int mode = PFM_PLM3; // user mode only + int ret = pfm_get_os_event_encoding(name.c_str(), mode, PFM_OS_PERF_EVENT_EXT, + &arg); + return (ret == PFM_SUCCESS); +} + +PerfCounters PerfCounters::Create( + const std::vector& counter_names) { + if (!counter_names.empty()) { + Initialize(); + } + + // Valid counters will populate these arrays but we start empty + std::vector valid_names; + std::vector counter_ids; + std::vector leader_ids; + + // Resize to the maximum possible + valid_names.reserve(counter_names.size()); + counter_ids.reserve(counter_names.size()); + + const int kCounterMode = PFM_PLM3; // user mode only + + // Group leads will be assigned on demand. The idea is that once we cannot + // create a counter descriptor, the reason is that this group has maxed out + // so we set the group_id again to -1 and retry - giving the algorithm a + // chance to create a new group leader to hold the next set of counters. + int group_id = -1; + + // Loop through all performance counters + for (size_t i = 0; i < counter_names.size(); ++i) { + // we are about to push into the valid names vector + // check if we did not reach the maximum + if (valid_names.size() == PerfCounterValues::kMaxCounters) { + // Log a message if we maxed out and stop adding + GetErrorLogInstance() + << counter_names.size() << " counters were requested. The maximum is " + << PerfCounterValues::kMaxCounters << " and " << valid_names.size() + << " were already added. 
All remaining counters will be ignored\n"; + // stop the loop and return what we have already + break; + } + + // Check if this name is empty + const auto& name = counter_names[i]; + if (name.empty()) { + GetErrorLogInstance() + << "A performance counter name was the empty string\n"; + continue; + } + + // Here first means first in group, ie the group leader + const bool is_first = (group_id < 0); + + // This struct will be populated by libpfm from the counter string + // and then fed into the syscall perf_event_open + struct perf_event_attr attr {}; + attr.size = sizeof(attr); + + // This is the input struct to libpfm. + pfm_perf_encode_arg_t arg{}; + arg.attr = &attr; + const int pfm_get = pfm_get_os_event_encoding(name.c_str(), kCounterMode, + PFM_OS_PERF_EVENT, &arg); + if (pfm_get != PFM_SUCCESS) { + GetErrorLogInstance() + << "Unknown performance counter name: " << name << "\n"; + continue; + } + + // We then proceed to populate the remaining fields in our attribute struct + // Note: the man page for perf_event_create suggests inherit = true and + // read_format = PERF_FORMAT_GROUP don't work together, but that's not the + // case. + attr.disabled = is_first; + attr.inherit = true; + attr.pinned = is_first; + attr.exclude_kernel = true; + attr.exclude_user = false; + attr.exclude_hv = true; + + // Read all counters in a group in one read. + attr.read_format = PERF_FORMAT_GROUP; //| PERF_FORMAT_TOTAL_TIME_ENABLED | + // PERF_FORMAT_TOTAL_TIME_RUNNING; + + int id = -1; + while (id < 0) { + static constexpr size_t kNrOfSyscallRetries = 5; + // Retry syscall as it was interrupted often (b/64774091). + for (size_t num_retries = 0; num_retries < kNrOfSyscallRetries; + ++num_retries) { + id = perf_event_open(&attr, 0, -1, group_id, 0); + if (id >= 0 || errno != EINTR) { + break; + } + } + if (id < 0) { + // If the file descriptor is negative we might have reached a limit + // in the current group. 
Set the group_id to -1 and retry + if (group_id >= 0) { + // Create a new group + group_id = -1; + } else { + // At this point we have already retried to set a new group id and + // failed. We then give up. + break; + } + } + } + + // We failed to get a new file descriptor. We might have reached a hard + // hardware limit that cannot be resolved even with group multiplexing + if (id < 0) { + GetErrorLogInstance() << "***WARNING** Failed to get a file descriptor " + "for performance counter " + << name << ". Ignoring\n"; + + // We give up on this counter but try to keep going + // as the others would be fine + continue; + } + if (group_id < 0) { + // This is a leader, store and assign it to the current file descriptor + leader_ids.push_back(id); + group_id = id; + } + // This is a valid counter, add it to our descriptor's list + counter_ids.push_back(id); + valid_names.push_back(name); + } + + // Loop through all group leaders activating them + // There is another option of starting ALL counters in a process but + // that would be far reaching an intrusion. If the user is using PMCs + // by themselves then this would have a side effect on them. It is + // friendlier to loop through all groups individually. + for (int lead : leader_ids) { + if (ioctl(lead, PERF_EVENT_IOC_ENABLE) != 0) { + // This should never happen but if it does, we give up on the + // entire batch as recovery would be a mess. + GetErrorLogInstance() << "***WARNING*** Failed to start counters. 
" + "Claring out all counters.\n"; + + // Close all performance counters + for (int id : counter_ids) { + ::close(id); + } + + // Return an empty object so our internal state is still good and + // the process can continue normally without impact + return NoCounters(); + } + } + + return PerfCounters(std::move(valid_names), std::move(counter_ids), + std::move(leader_ids)); +} + +void PerfCounters::CloseCounters() const { + if (counter_ids_.empty()) { + return; + } + for (int lead : leader_ids_) { + ioctl(lead, PERF_EVENT_IOC_DISABLE); + } + for (int fd : counter_ids_) { + close(fd); + } +} +#else // defined HAVE_LIBPFM +size_t PerfCounterValues::Read(const std::vector&) { return 0; } + +const bool PerfCounters::kSupported = false; + +bool PerfCounters::Initialize() { return false; } + +bool PerfCounters::IsCounterSupported(const std::string&) { return false; } + +PerfCounters PerfCounters::Create( + const std::vector& counter_names) { + if (!counter_names.empty()) { + GetErrorLogInstance() << "Performance counters not supported.\n"; + } + return NoCounters(); +} + +void PerfCounters::CloseCounters() const {} +#endif // defined HAVE_LIBPFM + +PerfCountersMeasurement::PerfCountersMeasurement( + const std::vector& counter_names) + : start_values_(counter_names.size()), end_values_(counter_names.size()) { + counters_ = PerfCounters::Create(counter_names); +} + +PerfCounters& PerfCounters::operator=(PerfCounters&& other) noexcept { + if (this != &other) { + CloseCounters(); + + counter_ids_ = std::move(other.counter_ids_); + leader_ids_ = std::move(other.leader_ids_); + counter_names_ = std::move(other.counter_names_); + } + return *this; +} +} // namespace internal +} // namespace benchmark diff --git a/third_party/benchmark/src/perf_counters.h b/third_party/benchmark/src/perf_counters.h new file mode 100644 index 0000000..bf5eb6b --- /dev/null +++ b/third_party/benchmark/src/perf_counters.h @@ -0,0 +1,200 @@ +// Copyright 2021 Google Inc. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef BENCHMARK_PERF_COUNTERS_H +#define BENCHMARK_PERF_COUNTERS_H + +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" +#include "log.h" +#include "mutex.h" + +#ifndef BENCHMARK_OS_WINDOWS +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(push) +// C4251: needs to have dll-interface to be used by clients of class +#pragma warning(disable : 4251) +#endif + +namespace benchmark { +namespace internal { + +// Typically, we can only read a small number of counters. There is also a +// padding preceding counter values, when reading multiple counters with one +// syscall (which is desirable). PerfCounterValues abstracts these details. +// The implementation ensures the storage is inlined, and allows 0-based +// indexing into the counter values. +// The object is used in conjunction with a PerfCounters object, by passing it +// to Snapshot(). The Read() method relocates individual reads, discarding +// the initial padding from each group leader in the values buffer such that +// all user accesses through the [] operator are correct. 
+class BENCHMARK_EXPORT PerfCounterValues { + public: + explicit PerfCounterValues(size_t nr_counters) : nr_counters_(nr_counters) { + BM_CHECK_LE(nr_counters_, kMaxCounters); + } + + // We are reading correctly now so the values don't need to skip padding + uint64_t operator[](size_t pos) const { return values_[pos]; } + + // Increased the maximum to 32 only since the buffer + // is std::array<> backed + static constexpr size_t kMaxCounters = 32; + + private: + friend class PerfCounters; + // Get the byte buffer in which perf counters can be captured. + // This is used by PerfCounters::Read + std::pair get_data_buffer() { + return {reinterpret_cast(values_.data()), + sizeof(uint64_t) * (kPadding + nr_counters_)}; + } + + // This reading is complex and as the goal of this class is to + // abstract away the intrincacies of the reading process, this is + // a better place for it + size_t Read(const std::vector& leaders); + + // Move the padding to 2 due to the reading algorithm (1st padding plus a + // current read padding) + static constexpr size_t kPadding = 2; + std::array values_; + const size_t nr_counters_; +}; + +// Collect PMU counters. The object, once constructed, is ready to be used by +// calling read(). PMU counter collection is enabled from the time create() is +// called, to obtain the object, until the object's destructor is called. +class BENCHMARK_EXPORT PerfCounters final { + public: + // True iff this platform supports performance counters. + static const bool kSupported; + + // Returns an empty object + static PerfCounters NoCounters() { return PerfCounters(); } + + ~PerfCounters() { CloseCounters(); } + PerfCounters() = default; + PerfCounters(PerfCounters&&) = default; + PerfCounters(const PerfCounters&) = delete; + PerfCounters& operator=(PerfCounters&&) noexcept; + PerfCounters& operator=(const PerfCounters&) = delete; + + // Platform-specific implementations may choose to do some library + // initialization here. 
+ static bool Initialize(); + + // Check if the given counter is supported, if the app wants to + // check before passing + static bool IsCounterSupported(const std::string& name); + + // Return a PerfCounters object ready to read the counters with the names + // specified. The values are user-mode only. The counter name format is + // implementation and OS specific. + // In case of failure, this method will in the worst case return an + // empty object whose state will still be valid. + static PerfCounters Create(const std::vector& counter_names); + + // Take a snapshot of the current value of the counters into the provided + // valid PerfCounterValues storage. The values are populated such that: + // names()[i]'s value is (*values)[i] + BENCHMARK_ALWAYS_INLINE bool Snapshot(PerfCounterValues* values) const { +#ifndef BENCHMARK_OS_WINDOWS + assert(values != nullptr); + return values->Read(leader_ids_) == counter_ids_.size(); +#else + (void)values; + return false; +#endif + } + + const std::vector& names() const { return counter_names_; } + size_t num_counters() const { return counter_names_.size(); } + + private: + PerfCounters(const std::vector& counter_names, + std::vector&& counter_ids, std::vector&& leader_ids) + : counter_ids_(std::move(counter_ids)), + leader_ids_(std::move(leader_ids)), + counter_names_(counter_names) {} + + void CloseCounters() const; + + std::vector counter_ids_; + std::vector leader_ids_; + std::vector counter_names_; +}; + +// Typical usage of the above primitives. +class BENCHMARK_EXPORT PerfCountersMeasurement final { + public: + PerfCountersMeasurement(const std::vector& counter_names); + + size_t num_counters() const { return counters_.num_counters(); } + + std::vector names() const { return counters_.names(); } + + BENCHMARK_ALWAYS_INLINE bool Start() { + if (num_counters() == 0) return true; + // Tell the compiler to not move instructions above/below where we take + // the snapshot. 
+ ClobberMemory(); + valid_read_ &= counters_.Snapshot(&start_values_); + ClobberMemory(); + + return valid_read_; + } + + BENCHMARK_ALWAYS_INLINE bool Stop( + std::vector>& measurements) { + if (num_counters() == 0) return true; + // Tell the compiler to not move instructions above/below where we take + // the snapshot. + ClobberMemory(); + valid_read_ &= counters_.Snapshot(&end_values_); + ClobberMemory(); + + for (size_t i = 0; i < counters_.names().size(); ++i) { + double measurement = static_cast(end_values_[i]) - + static_cast(start_values_[i]); + measurements.push_back({counters_.names()[i], measurement}); + } + + return valid_read_; + } + + private: + PerfCounters counters_; + bool valid_read_ = true; + PerfCounterValues start_values_; + PerfCounterValues end_values_; +}; + +} // namespace internal +} // namespace benchmark + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#endif // BENCHMARK_PERF_COUNTERS_H diff --git a/third_party/benchmark/src/re.h b/third_party/benchmark/src/re.h new file mode 100644 index 0000000..9afb869 --- /dev/null +++ b/third_party/benchmark/src/re.h @@ -0,0 +1,158 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef BENCHMARK_RE_H_ +#define BENCHMARK_RE_H_ + +#include "internal_macros.h" + +// clang-format off + +#if !defined(HAVE_STD_REGEX) && \ + !defined(HAVE_GNU_POSIX_REGEX) && \ + !defined(HAVE_POSIX_REGEX) + // No explicit regex selection; detect based on builtin hints. + #if defined(BENCHMARK_OS_LINUX) || defined(BENCHMARK_OS_APPLE) + #define HAVE_POSIX_REGEX 1 + #elif __cplusplus >= 199711L + #define HAVE_STD_REGEX 1 + #endif +#endif + +// Prefer C regex libraries when compiling w/o exceptions so that we can +// correctly report errors. +#if defined(BENCHMARK_HAS_NO_EXCEPTIONS) && \ + defined(HAVE_STD_REGEX) && \ + (defined(HAVE_GNU_POSIX_REGEX) || defined(HAVE_POSIX_REGEX)) + #undef HAVE_STD_REGEX +#endif + +#if defined(HAVE_STD_REGEX) + #include +#elif defined(HAVE_GNU_POSIX_REGEX) + #include +#elif defined(HAVE_POSIX_REGEX) + #include +#else +#error No regular expression backend was found! +#endif + +// clang-format on + +#include + +#include "check.h" + +namespace benchmark { + +// A wrapper around the POSIX regular expression API that provides automatic +// cleanup +class Regex { + public: + Regex() : init_(false) {} + + ~Regex(); + + // Compile a regular expression matcher from spec. Returns true on success. + // + // On failure (and if error is not nullptr), error is populated with a human + // readable error message if an error occurs. + bool Init(const std::string& spec, std::string* error); + + // Returns whether str matches the compiled regular expression. 
+ bool Match(const std::string& str); + + private: + bool init_; +// Underlying regular expression object +#if defined(HAVE_STD_REGEX) + std::regex re_; +#elif defined(HAVE_POSIX_REGEX) || defined(HAVE_GNU_POSIX_REGEX) + regex_t re_; +#else +#error No regular expression backend implementation available +#endif +}; + +#if defined(HAVE_STD_REGEX) + +inline bool Regex::Init(const std::string& spec, std::string* error) { +#ifdef BENCHMARK_HAS_NO_EXCEPTIONS + ((void)error); // suppress unused warning +#else + try { +#endif + re_ = std::regex(spec, std::regex_constants::extended); + init_ = true; +#ifndef BENCHMARK_HAS_NO_EXCEPTIONS +} +catch (const std::regex_error& e) { + if (error) { + *error = e.what(); + } +} +#endif +return init_; +} + +inline Regex::~Regex() {} + +inline bool Regex::Match(const std::string& str) { + if (!init_) { + return false; + } + return std::regex_search(str, re_); +} + +#else +inline bool Regex::Init(const std::string& spec, std::string* error) { + int ec = regcomp(&re_, spec.c_str(), REG_EXTENDED | REG_NOSUB); + if (ec != 0) { + if (error) { + size_t needed = regerror(ec, &re_, nullptr, 0); + char* errbuf = new char[needed]; + regerror(ec, &re_, errbuf, needed); + + // regerror returns the number of bytes necessary to null terminate + // the string, so we move that when assigning to error. 
+ BM_CHECK_NE(needed, 0); + error->assign(errbuf, needed - 1); + + delete[] errbuf; + } + + return false; + } + + init_ = true; + return true; +} + +inline Regex::~Regex() { + if (init_) { + regfree(&re_); + } +} + +inline bool Regex::Match(const std::string& str) { + if (!init_) { + return false; + } + return regexec(&re_, str.c_str(), 0, nullptr, 0) == 0; +} +#endif + +} // end namespace benchmark + +#endif // BENCHMARK_RE_H_ diff --git a/third_party/benchmark/src/reporter.cc b/third_party/benchmark/src/reporter.cc new file mode 100644 index 0000000..076bc31 --- /dev/null +++ b/third_party/benchmark/src/reporter.cc @@ -0,0 +1,118 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" +#include "string_util.h" +#include "timers.h" + +namespace benchmark { + +BenchmarkReporter::BenchmarkReporter() + : output_stream_(&std::cout), error_stream_(&std::cerr) {} + +BenchmarkReporter::~BenchmarkReporter() {} + +void BenchmarkReporter::PrintBasicContext(std::ostream *out, + Context const &context) { + BM_CHECK(out) << "cannot be null"; + auto &Out = *out; + +#ifndef BENCHMARK_OS_QURT + // Date/time information is not available on QuRT. + // Attempting to get it via this call cause the binary to crash. 
+ Out << LocalDateTimeString() << "\n"; +#endif + + if (context.executable_name) + Out << "Running " << context.executable_name << "\n"; + + const CPUInfo &info = context.cpu_info; + Out << "Run on (" << info.num_cpus << " X " + << (info.cycles_per_second / 1000000.0) << " MHz CPU " + << ((info.num_cpus > 1) ? "s" : "") << ")\n"; + if (info.caches.size() != 0) { + Out << "CPU Caches:\n"; + for (auto &CInfo : info.caches) { + Out << " L" << CInfo.level << " " << CInfo.type << " " + << (CInfo.size / 1024) << " KiB"; + if (CInfo.num_sharing != 0) + Out << " (x" << (info.num_cpus / CInfo.num_sharing) << ")"; + Out << "\n"; + } + } + if (!info.load_avg.empty()) { + Out << "Load Average: "; + for (auto It = info.load_avg.begin(); It != info.load_avg.end();) { + Out << StrFormat("%.2f", *It++); + if (It != info.load_avg.end()) Out << ", "; + } + Out << "\n"; + } + + std::map *global_context = + internal::GetGlobalContext(); + + if (global_context != nullptr) { + for (const auto &kv : *global_context) { + Out << kv.first << ": " << kv.second << "\n"; + } + } + + if (CPUInfo::Scaling::ENABLED == info.scaling) { + Out << "***WARNING*** CPU scaling is enabled, the benchmark " + "real time measurements may be noisy and will incur extra " + "overhead.\n"; + } + +#ifndef NDEBUG + Out << "***WARNING*** Library was built as DEBUG. Timings may be " + "affected.\n"; +#endif +} + +// No initializer because it's already initialized to NULL. 
+const char *BenchmarkReporter::Context::executable_name; + +BenchmarkReporter::Context::Context() + : cpu_info(CPUInfo::Get()), sys_info(SystemInfo::Get()) {} + +std::string BenchmarkReporter::Run::benchmark_name() const { + std::string name = run_name.str(); + if (run_type == RT_Aggregate) { + name += "_" + aggregate_name; + } + return name; +} + +double BenchmarkReporter::Run::GetAdjustedRealTime() const { + double new_time = real_accumulated_time * GetTimeUnitMultiplier(time_unit); + if (iterations != 0) new_time /= static_cast(iterations); + return new_time; +} + +double BenchmarkReporter::Run::GetAdjustedCPUTime() const { + double new_time = cpu_accumulated_time * GetTimeUnitMultiplier(time_unit); + if (iterations != 0) new_time /= static_cast(iterations); + return new_time; +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/statistics.cc b/third_party/benchmark/src/statistics.cc new file mode 100644 index 0000000..16b6026 --- /dev/null +++ b/third_party/benchmark/src/statistics.cc @@ -0,0 +1,214 @@ +// Copyright 2016 Ismael Jimenez Martinez. All rights reserved. +// Copyright 2017 Roman Lebedev. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "statistics.h" + +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" + +namespace benchmark { + +auto StatisticsSum = [](const std::vector& v) { + return std::accumulate(v.begin(), v.end(), 0.0); +}; + +double StatisticsMean(const std::vector& v) { + if (v.empty()) return 0.0; + return StatisticsSum(v) * (1.0 / static_cast(v.size())); +} + +double StatisticsMedian(const std::vector& v) { + if (v.size() < 3) return StatisticsMean(v); + std::vector copy(v); + + auto center = copy.begin() + v.size() / 2; + std::nth_element(copy.begin(), center, copy.end()); + + // Did we have an odd number of samples? If yes, then center is the median. + // If not, then we are looking for the average between center and the value + // before. Instead of resorting, we just look for the max value before it, + // which is not necessarily the element immediately preceding `center` Since + // `copy` is only partially sorted by `nth_element`. + if (v.size() % 2 == 1) return *center; + auto center2 = std::max_element(copy.begin(), center); + return (*center + *center2) / 2.0; +} + +// Return the sum of the squares of this sample set +auto SumSquares = [](const std::vector& v) { + return std::inner_product(v.begin(), v.end(), v.begin(), 0.0); +}; + +auto Sqr = [](const double dat) { return dat * dat; }; +auto Sqrt = [](const double dat) { + // Avoid NaN due to imprecision in the calculations + if (dat < 0.0) return 0.0; + return std::sqrt(dat); +}; + +double StatisticsStdDev(const std::vector& v) { + const auto mean = StatisticsMean(v); + if (v.empty()) return mean; + + // Sample standard deviation is undefined for n = 1 + if (v.size() == 1) return 0.0; + + const double avg_squares = + SumSquares(v) * (1.0 / static_cast(v.size())); + return Sqrt(static_cast(v.size()) / + (static_cast(v.size()) - 1.0) * + (avg_squares - Sqr(mean))); +} + +double StatisticsCV(const std::vector& v) { + if (v.size() < 2) return 0.0; + + const auto 
stddev = StatisticsStdDev(v); + const auto mean = StatisticsMean(v); + + if (std::fpclassify(mean) == FP_ZERO) return 0.0; + + return stddev / mean; +} + +std::vector ComputeStats( + const std::vector& reports) { + typedef BenchmarkReporter::Run Run; + std::vector results; + + auto error_count = std::count_if(reports.begin(), reports.end(), + [](Run const& run) { return run.skipped; }); + + if (reports.size() - static_cast(error_count) < 2) { + // We don't report aggregated data if there was a single run. + return results; + } + + // Accumulators. + std::vector real_accumulated_time_stat; + std::vector cpu_accumulated_time_stat; + + real_accumulated_time_stat.reserve(reports.size()); + cpu_accumulated_time_stat.reserve(reports.size()); + + // All repetitions should be run with the same number of iterations so we + // can take this information from the first benchmark. + const IterationCount run_iterations = reports.front().iterations; + // create stats for user counters + struct CounterStat { + Counter c; + std::vector s; + }; + std::map counter_stats; + for (Run const& r : reports) { + for (auto const& cnt : r.counters) { + auto it = counter_stats.find(cnt.first); + if (it == counter_stats.end()) { + it = counter_stats + .emplace(cnt.first, + CounterStat{cnt.second, std::vector{}}) + .first; + it->second.s.reserve(reports.size()); + } else { + BM_CHECK_EQ(it->second.c.flags, cnt.second.flags); + } + } + } + + // Populate the accumulators. 
+ for (Run const& run : reports) { + BM_CHECK_EQ(reports[0].benchmark_name(), run.benchmark_name()); + BM_CHECK_EQ(run_iterations, run.iterations); + if (run.skipped) continue; + real_accumulated_time_stat.emplace_back(run.real_accumulated_time); + cpu_accumulated_time_stat.emplace_back(run.cpu_accumulated_time); + // user counters + for (auto const& cnt : run.counters) { + auto it = counter_stats.find(cnt.first); + BM_CHECK_NE(it, counter_stats.end()); + it->second.s.emplace_back(cnt.second); + } + } + + // Only add label if it is same for all runs + std::string report_label = reports[0].report_label; + for (std::size_t i = 1; i < reports.size(); i++) { + if (reports[i].report_label != report_label) { + report_label = ""; + break; + } + } + + const double iteration_rescale_factor = + double(reports.size()) / double(run_iterations); + + for (const auto& Stat : *reports[0].statistics) { + // Get the data from the accumulator to BenchmarkReporter::Run's. + Run data; + data.run_name = reports[0].run_name; + data.family_index = reports[0].family_index; + data.per_family_instance_index = reports[0].per_family_instance_index; + data.run_type = BenchmarkReporter::Run::RT_Aggregate; + data.threads = reports[0].threads; + data.repetitions = reports[0].repetitions; + data.repetition_index = Run::no_repetition_index; + data.aggregate_name = Stat.name_; + data.aggregate_unit = Stat.unit_; + data.report_label = report_label; + + // It is incorrect to say that an aggregate is computed over + // run's iterations, because those iterations already got averaged. + // Similarly, if there are N repetitions with 1 iterations each, + // an aggregate will be computed over N measurements, not 1. + // Thus it is best to simply use the count of separate reports. 
+ data.iterations = static_cast(reports.size()); + + data.real_accumulated_time = Stat.compute_(real_accumulated_time_stat); + data.cpu_accumulated_time = Stat.compute_(cpu_accumulated_time_stat); + + if (data.aggregate_unit == StatisticUnit::kTime) { + // We will divide these times by data.iterations when reporting, but the + // data.iterations is not necessarily the scale of these measurements, + // because in each repetition, these timers are sum over all the iters. + // And if we want to say that the stats are over N repetitions and not + // M iterations, we need to multiply these by (N/M). + data.real_accumulated_time *= iteration_rescale_factor; + data.cpu_accumulated_time *= iteration_rescale_factor; + } + + data.time_unit = reports[0].time_unit; + + // user counters + for (auto const& kv : counter_stats) { + // Do *NOT* rescale the custom counters. They are already properly scaled. + const auto uc_stat = Stat.compute_(kv.second.s); + auto c = Counter(uc_stat, counter_stats[kv.first].c.flags, + counter_stats[kv.first].c.oneK); + data.counters[kv.first] = c; + } + + results.push_back(data); + } + + return results; +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/statistics.h b/third_party/benchmark/src/statistics.h new file mode 100644 index 0000000..6e5560e --- /dev/null +++ b/third_party/benchmark/src/statistics.h @@ -0,0 +1,44 @@ +// Copyright 2016 Ismael Jimenez Martinez. All rights reserved. +// Copyright 2017 Roman Lebedev. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef STATISTICS_H_ +#define STATISTICS_H_ + +#include + +#include "benchmark/benchmark.h" + +namespace benchmark { + +// Return a vector containing the mean, median and standard deviation +// information (and any user-specified info) for the specified list of reports. +// If 'reports' contains less than two non-errored runs an empty vector is +// returned +BENCHMARK_EXPORT +std::vector ComputeStats( + const std::vector& reports); + +BENCHMARK_EXPORT +double StatisticsMean(const std::vector& v); +BENCHMARK_EXPORT +double StatisticsMedian(const std::vector& v); +BENCHMARK_EXPORT +double StatisticsStdDev(const std::vector& v); +BENCHMARK_EXPORT +double StatisticsCV(const std::vector& v); + +} // end namespace benchmark + +#endif // STATISTICS_H_ diff --git a/third_party/benchmark/src/string_util.cc b/third_party/benchmark/src/string_util.cc new file mode 100644 index 0000000..9ba63a7 --- /dev/null +++ b/third_party/benchmark/src/string_util.cc @@ -0,0 +1,254 @@ +#include "string_util.h" + +#include +#ifdef BENCHMARK_STL_ANDROID_GNUSTL +#include +#endif +#include +#include +#include +#include +#include + +#include "arraysize.h" +#include "benchmark/benchmark.h" + +namespace benchmark { +namespace { +// kilo, Mega, Giga, Tera, Peta, Exa, Zetta, Yotta. +const char* const kBigSIUnits[] = {"k", "M", "G", "T", "P", "E", "Z", "Y"}; +// Kibi, Mebi, Gibi, Tebi, Pebi, Exbi, Zebi, Yobi. +const char* const kBigIECUnits[] = {"Ki", "Mi", "Gi", "Ti", + "Pi", "Ei", "Zi", "Yi"}; +// milli, micro, nano, pico, femto, atto, zepto, yocto. +const char* const kSmallSIUnits[] = {"m", "u", "n", "p", "f", "a", "z", "y"}; + +// We require that all three arrays have the same size. 
+static_assert(arraysize(kBigSIUnits) == arraysize(kBigIECUnits), + "SI and IEC unit arrays must be the same size"); +static_assert(arraysize(kSmallSIUnits) == arraysize(kBigSIUnits), + "Small SI and Big SI unit arrays must be the same size"); + +static const int64_t kUnitsSize = arraysize(kBigSIUnits); + +void ToExponentAndMantissa(double val, int precision, double one_k, + std::string* mantissa, int64_t* exponent) { + std::stringstream mantissa_stream; + + if (val < 0) { + mantissa_stream << "-"; + val = -val; + } + + // Adjust threshold so that it never excludes things which can't be rendered + // in 'precision' digits. + const double adjusted_threshold = + std::max(1.0, 1.0 / std::pow(10.0, precision)); + const double big_threshold = (adjusted_threshold * one_k) - 1; + const double small_threshold = adjusted_threshold; + // Values in ]simple_threshold,small_threshold[ will be printed as-is + const double simple_threshold = 0.01; + + if (val > big_threshold) { + // Positive powers + double scaled = val; + for (size_t i = 0; i < arraysize(kBigSIUnits); ++i) { + scaled /= one_k; + if (scaled <= big_threshold) { + mantissa_stream << scaled; + *exponent = static_cast(i + 1); + *mantissa = mantissa_stream.str(); + return; + } + } + mantissa_stream << val; + *exponent = 0; + } else if (val < small_threshold) { + // Negative powers + if (val < simple_threshold) { + double scaled = val; + for (size_t i = 0; i < arraysize(kSmallSIUnits); ++i) { + scaled *= one_k; + if (scaled >= small_threshold) { + mantissa_stream << scaled; + *exponent = -static_cast(i + 1); + *mantissa = mantissa_stream.str(); + return; + } + } + } + mantissa_stream << val; + *exponent = 0; + } else { + mantissa_stream << val; + *exponent = 0; + } + *mantissa = mantissa_stream.str(); +} + +std::string ExponentToPrefix(int64_t exponent, bool iec) { + if (exponent == 0) return ""; + + const int64_t index = (exponent > 0 ? 
exponent - 1 : -exponent - 1); + if (index >= kUnitsSize) return ""; + + const char* const* array = + (exponent > 0 ? (iec ? kBigIECUnits : kBigSIUnits) : kSmallSIUnits); + + return std::string(array[index]); +} + +std::string ToBinaryStringFullySpecified(double value, int precision, + Counter::OneK one_k) { + std::string mantissa; + int64_t exponent; + ToExponentAndMantissa(value, precision, + one_k == Counter::kIs1024 ? 1024.0 : 1000.0, &mantissa, + &exponent); + return mantissa + ExponentToPrefix(exponent, one_k == Counter::kIs1024); +} + +std::string StrFormatImp(const char* msg, va_list args) { + // we might need a second shot at this, so pre-emptivly make a copy + va_list args_cp; + va_copy(args_cp, args); + + // TODO(ericwf): use std::array for first attempt to avoid one memory + // allocation guess what the size might be + std::array local_buff; + + // 2015-10-08: vsnprintf is used instead of snd::vsnprintf due to a limitation + // in the android-ndk + auto ret = vsnprintf(local_buff.data(), local_buff.size(), msg, args_cp); + + va_end(args_cp); + + // handle empty expansion + if (ret == 0) return std::string{}; + if (static_cast(ret) < local_buff.size()) + return std::string(local_buff.data()); + + // we did not provide a long enough buffer on our first attempt. + // add 1 to size to account for null-byte in size cast to prevent overflow + std::size_t size = static_cast(ret) + 1; + auto buff_ptr = std::unique_ptr(new char[size]); + // 2015-10-08: vsnprintf is used instead of snd::vsnprintf due to a limitation + // in the android-ndk + vsnprintf(buff_ptr.get(), size, msg, args); + return std::string(buff_ptr.get()); +} + +} // end namespace + +std::string HumanReadableNumber(double n, Counter::OneK one_k) { + return ToBinaryStringFullySpecified(n, 1, one_k); +} + +std::string StrFormat(const char* format, ...) 
{ + va_list args; + va_start(args, format); + std::string tmp = StrFormatImp(format, args); + va_end(args); + return tmp; +} + +std::vector StrSplit(const std::string& str, char delim) { + if (str.empty()) return {}; + std::vector ret; + size_t first = 0; + size_t next = str.find(delim); + for (; next != std::string::npos; + first = next + 1, next = str.find(delim, first)) { + ret.push_back(str.substr(first, next - first)); + } + ret.push_back(str.substr(first)); + return ret; +} + +#ifdef BENCHMARK_STL_ANDROID_GNUSTL +/* + * GNU STL in Android NDK lacks support for some C++11 functions, including + * stoul, stoi, stod. We reimplement them here using C functions strtoul, + * strtol, strtod. Note that reimplemented functions are in benchmark:: + * namespace, not std:: namespace. + */ +unsigned long stoul(const std::string& str, size_t* pos, int base) { + /* Record previous errno */ + const int oldErrno = errno; + errno = 0; + + const char* strStart = str.c_str(); + char* strEnd = const_cast(strStart); + const unsigned long result = strtoul(strStart, &strEnd, base); + + const int strtoulErrno = errno; + /* Restore previous errno */ + errno = oldErrno; + + /* Check for errors and return */ + if (strtoulErrno == ERANGE) { + throw std::out_of_range("stoul failed: " + str + + " is outside of range of unsigned long"); + } else if (strEnd == strStart || strtoulErrno != 0) { + throw std::invalid_argument("stoul failed: " + str + " is not an integer"); + } + if (pos != nullptr) { + *pos = static_cast(strEnd - strStart); + } + return result; +} + +int stoi(const std::string& str, size_t* pos, int base) { + /* Record previous errno */ + const int oldErrno = errno; + errno = 0; + + const char* strStart = str.c_str(); + char* strEnd = const_cast(strStart); + const long result = strtol(strStart, &strEnd, base); + + const int strtolErrno = errno; + /* Restore previous errno */ + errno = oldErrno; + + /* Check for errors and return */ + if (strtolErrno == ERANGE || 
long(int(result)) != result) { + throw std::out_of_range("stoul failed: " + str + + " is outside of range of int"); + } else if (strEnd == strStart || strtolErrno != 0) { + throw std::invalid_argument("stoul failed: " + str + " is not an integer"); + } + if (pos != nullptr) { + *pos = static_cast(strEnd - strStart); + } + return int(result); +} + +double stod(const std::string& str, size_t* pos) { + /* Record previous errno */ + const int oldErrno = errno; + errno = 0; + + const char* strStart = str.c_str(); + char* strEnd = const_cast(strStart); + const double result = strtod(strStart, &strEnd); + + /* Restore previous errno */ + const int strtodErrno = errno; + errno = oldErrno; + + /* Check for errors and return */ + if (strtodErrno == ERANGE) { + throw std::out_of_range("stoul failed: " + str + + " is outside of range of int"); + } else if (strEnd == strStart || strtodErrno != 0) { + throw std::invalid_argument("stoul failed: " + str + " is not an integer"); + } + if (pos != nullptr) { + *pos = static_cast(strEnd - strStart); + } + return result; +} +#endif + +} // end namespace benchmark diff --git a/third_party/benchmark/src/string_util.h b/third_party/benchmark/src/string_util.h new file mode 100644 index 0000000..731aa2c --- /dev/null +++ b/third_party/benchmark/src/string_util.h @@ -0,0 +1,70 @@ +#ifndef BENCHMARK_STRING_UTIL_H_ +#define BENCHMARK_STRING_UTIL_H_ + +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "benchmark/export.h" +#include "check.h" +#include "internal_macros.h" + +namespace benchmark { + +BENCHMARK_EXPORT +std::string HumanReadableNumber(double n, Counter::OneK one_k); + +BENCHMARK_EXPORT +#if defined(__MINGW32__) +__attribute__((format(__MINGW_PRINTF_FORMAT, 1, 2))) +#elif defined(__GNUC__) +__attribute__((format(printf, 1, 2))) +#endif +std::string +StrFormat(const char* format, ...); + +inline std::ostream& StrCatImp(std::ostream& out) BENCHMARK_NOEXCEPT { + return out; +} + +template +inline 
std::ostream& StrCatImp(std::ostream& out, First&& f, Rest&&... rest) { + out << std::forward(f); + return StrCatImp(out, std::forward(rest)...); +} + +template +inline std::string StrCat(Args&&... args) { + std::ostringstream ss; + StrCatImp(ss, std::forward(args)...); + return ss.str(); +} + +BENCHMARK_EXPORT +std::vector StrSplit(const std::string& str, char delim); + +// Disable lint checking for this block since it re-implements C functions. +// NOLINTBEGIN +#ifdef BENCHMARK_STL_ANDROID_GNUSTL +/* + * GNU STL in Android NDK lacks support for some C++11 functions, including + * stoul, stoi, stod. We reimplement them here using C functions strtoul, + * strtol, strtod. Note that reimplemented functions are in benchmark:: + * namespace, not std:: namespace. + */ +unsigned long stoul(const std::string& str, size_t* pos = nullptr, + int base = 10); +int stoi(const std::string& str, size_t* pos = nullptr, int base = 10); +double stod(const std::string& str, size_t* pos = nullptr); +#else +using std::stod; // NOLINT(misc-unused-using-decls) +using std::stoi; // NOLINT(misc-unused-using-decls) +using std::stoul; // NOLINT(misc-unused-using-decls) +#endif +// NOLINTEND + +} // end namespace benchmark + +#endif // BENCHMARK_STRING_UTIL_H_ diff --git a/third_party/benchmark/src/sysinfo.cc b/third_party/benchmark/src/sysinfo.cc new file mode 100644 index 0000000..7148598 --- /dev/null +++ b/third_party/benchmark/src/sysinfo.cc @@ -0,0 +1,877 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal_macros.h" + +#ifdef BENCHMARK_OS_WINDOWS +#if !defined(WINVER) || WINVER < 0x0600 +#undef WINVER +#define WINVER 0x0600 +#endif // WINVER handling +#include +#undef StrCat // Don't let StrCat in string_util.h be renamed to lstrcatA +#include +#include + +#include +#else +#include +#if !defined(BENCHMARK_OS_FUCHSIA) && !defined(BENCHMARK_OS_QURT) +#include +#endif +#include +#include // this header must be included before 'sys/sysctl.h' to avoid compilation error on FreeBSD +#include +#if defined BENCHMARK_OS_FREEBSD || defined BENCHMARK_OS_MACOSX || \ + defined BENCHMARK_OS_NETBSD || defined BENCHMARK_OS_OPENBSD || \ + defined BENCHMARK_OS_DRAGONFLY +#define BENCHMARK_HAS_SYSCTL +#include +#endif +#endif +#if defined(BENCHMARK_OS_SOLARIS) +#include +#include +#endif +#if defined(BENCHMARK_OS_QNX) +#include +#endif +#if defined(BENCHMARK_OS_QURT) +#include +#endif +#if defined(BENCHMARK_HAS_PTHREAD_AFFINITY) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "check.h" +#include "cycleclock.h" +#include "internal_macros.h" +#include "log.h" +#include "string_util.h" +#include "timers.h" + +namespace benchmark { +namespace { + +void PrintImp(std::ostream& out) { out << std::endl; } + +template +void PrintImp(std::ostream& out, First&& f, Rest&&... rest) { + out << std::forward(f); + PrintImp(out, std::forward(rest)...); +} + +template +BENCHMARK_NORETURN void PrintErrorAndDie(Args&&... args) { + PrintImp(std::cerr, std::forward(args)...); + std::exit(EXIT_FAILURE); +} + +#ifdef BENCHMARK_HAS_SYSCTL + +/// ValueUnion - A type used to correctly alias the byte-for-byte output of +/// `sysctl` with the result type it's to be interpreted as. 
+struct ValueUnion { + union DataT { + int32_t int32_value; + int64_t int64_value; + // For correct aliasing of union members from bytes. + char bytes[8]; + }; + using DataPtr = std::unique_ptr; + + // The size of the data union member + its trailing array size. + std::size_t size; + DataPtr buff; + + public: + ValueUnion() : size(0), buff(nullptr, &std::free) {} + + explicit ValueUnion(std::size_t buff_size) + : size(sizeof(DataT) + buff_size), + buff(::new (std::malloc(size)) DataT(), &std::free) {} + + ValueUnion(ValueUnion&& other) = default; + + explicit operator bool() const { return bool(buff); } + + char* data() const { return buff->bytes; } + + std::string GetAsString() const { return std::string(data()); } + + int64_t GetAsInteger() const { + if (size == sizeof(buff->int32_value)) + return buff->int32_value; + else if (size == sizeof(buff->int64_value)) + return buff->int64_value; + BENCHMARK_UNREACHABLE(); + } + + template + std::array GetAsArray() { + const int arr_size = sizeof(T) * N; + BM_CHECK_LE(arr_size, size); + std::array arr; + std::memcpy(arr.data(), data(), arr_size); + return arr; + } +}; + +ValueUnion GetSysctlImp(std::string const& name) { +#if defined BENCHMARK_OS_OPENBSD + int mib[2]; + + mib[0] = CTL_HW; + if ((name == "hw.ncpu") || (name == "hw.cpuspeed")) { + ValueUnion buff(sizeof(int)); + + if (name == "hw.ncpu") { + mib[1] = HW_NCPU; + } else { + mib[1] = HW_CPUSPEED; + } + + if (sysctl(mib, 2, buff.data(), &buff.size, nullptr, 0) == -1) { + return ValueUnion(); + } + return buff; + } + return ValueUnion(); +#else + std::size_t cur_buff_size = 0; + if (sysctlbyname(name.c_str(), nullptr, &cur_buff_size, nullptr, 0) == -1) + return ValueUnion(); + + ValueUnion buff(cur_buff_size); + if (sysctlbyname(name.c_str(), buff.data(), &buff.size, nullptr, 0) == 0) + return buff; + return ValueUnion(); +#endif +} + +BENCHMARK_MAYBE_UNUSED +bool GetSysctl(std::string const& name, std::string* out) { + out->clear(); + auto buff = 
GetSysctlImp(name); + if (!buff) return false; + out->assign(buff.data()); + return true; +} + +template ::value>::type> +bool GetSysctl(std::string const& name, Tp* out) { + *out = 0; + auto buff = GetSysctlImp(name); + if (!buff) return false; + *out = static_cast(buff.GetAsInteger()); + return true; +} + +template +bool GetSysctl(std::string const& name, std::array* out) { + auto buff = GetSysctlImp(name); + if (!buff) return false; + *out = buff.GetAsArray(); + return true; +} +#endif + +template +bool ReadFromFile(std::string const& fname, ArgT* arg) { + *arg = ArgT(); + std::ifstream f(fname.c_str()); + if (!f.is_open()) return false; + f >> *arg; + return f.good(); +} + +CPUInfo::Scaling CpuScaling(int num_cpus) { + // We don't have a valid CPU count, so don't even bother. + if (num_cpus <= 0) return CPUInfo::Scaling::UNKNOWN; +#if defined(BENCHMARK_OS_QNX) + return CPUInfo::Scaling::UNKNOWN; +#elif !defined(BENCHMARK_OS_WINDOWS) + // On Linux, the CPUfreq subsystem exposes CPU information as files on the + // local file system. If reading the exported files fails, then we may not be + // running on Linux, so we silently ignore all the read errors. 
+ std::string res; + for (int cpu = 0; cpu < num_cpus; ++cpu) { + std::string governor_file = + StrCat("/sys/devices/system/cpu/cpu", cpu, "/cpufreq/scaling_governor"); + if (ReadFromFile(governor_file, &res) && res != "performance") + return CPUInfo::Scaling::ENABLED; + } + return CPUInfo::Scaling::DISABLED; +#else + return CPUInfo::Scaling::UNKNOWN; +#endif +} + +int CountSetBitsInCPUMap(std::string val) { + auto CountBits = [](std::string part) { + using CPUMask = std::bitset; + part = "0x" + part; + CPUMask mask(benchmark::stoul(part, nullptr, 16)); + return static_cast(mask.count()); + }; + std::size_t pos; + int total = 0; + while ((pos = val.find(',')) != std::string::npos) { + total += CountBits(val.substr(0, pos)); + val = val.substr(pos + 1); + } + if (!val.empty()) { + total += CountBits(val); + } + return total; +} + +BENCHMARK_MAYBE_UNUSED +std::vector GetCacheSizesFromKVFS() { + std::vector res; + std::string dir = "/sys/devices/system/cpu/cpu0/cache/"; + int idx = 0; + while (true) { + CPUInfo::CacheInfo info; + std::string fpath = StrCat(dir, "index", idx++, "/"); + std::ifstream f(StrCat(fpath, "size").c_str()); + if (!f.is_open()) break; + std::string suffix; + f >> info.size; + if (f.fail()) + PrintErrorAndDie("Failed while reading file '", fpath, "size'"); + if (f.good()) { + f >> suffix; + if (f.bad()) + PrintErrorAndDie( + "Invalid cache size format: failed to read size suffix"); + else if (f && suffix != "K") + PrintErrorAndDie("Invalid cache size format: Expected bytes ", suffix); + else if (suffix == "K") + info.size *= 1024; + } + if (!ReadFromFile(StrCat(fpath, "type"), &info.type)) + PrintErrorAndDie("Failed to read from file ", fpath, "type"); + if (!ReadFromFile(StrCat(fpath, "level"), &info.level)) + PrintErrorAndDie("Failed to read from file ", fpath, "level"); + std::string map_str; + if (!ReadFromFile(StrCat(fpath, "shared_cpu_map"), &map_str)) + PrintErrorAndDie("Failed to read from file ", fpath, "shared_cpu_map"); + 
info.num_sharing = CountSetBitsInCPUMap(map_str); + res.push_back(info); + } + + return res; +} + +#ifdef BENCHMARK_OS_MACOSX +std::vector GetCacheSizesMacOSX() { + std::vector res; + std::array cache_counts{{0, 0, 0, 0}}; + GetSysctl("hw.cacheconfig", &cache_counts); + + struct { + std::string name; + std::string type; + int level; + int num_sharing; + } cases[] = {{"hw.l1dcachesize", "Data", 1, cache_counts[1]}, + {"hw.l1icachesize", "Instruction", 1, cache_counts[1]}, + {"hw.l2cachesize", "Unified", 2, cache_counts[2]}, + {"hw.l3cachesize", "Unified", 3, cache_counts[3]}}; + for (auto& c : cases) { + int val; + if (!GetSysctl(c.name, &val)) continue; + CPUInfo::CacheInfo info; + info.type = c.type; + info.level = c.level; + info.size = val; + info.num_sharing = c.num_sharing; + res.push_back(std::move(info)); + } + return res; +} +#elif defined(BENCHMARK_OS_WINDOWS) +std::vector GetCacheSizesWindows() { + std::vector res; + DWORD buffer_size = 0; + using PInfo = SYSTEM_LOGICAL_PROCESSOR_INFORMATION; + using CInfo = CACHE_DESCRIPTOR; + + using UPtr = std::unique_ptr; + GetLogicalProcessorInformation(nullptr, &buffer_size); + UPtr buff(static_cast(std::malloc(buffer_size)), &std::free); + if (!GetLogicalProcessorInformation(buff.get(), &buffer_size)) + PrintErrorAndDie("Failed during call to GetLogicalProcessorInformation: ", + GetLastError()); + + PInfo* it = buff.get(); + PInfo* end = buff.get() + (buffer_size / sizeof(PInfo)); + + for (; it != end; ++it) { + if (it->Relationship != RelationCache) continue; + using BitSet = std::bitset; + BitSet b(it->ProcessorMask); + // To prevent duplicates, only consider caches where CPU 0 is specified + if (!b.test(0)) continue; + const CInfo& cache = it->Cache; + CPUInfo::CacheInfo C; + C.num_sharing = static_cast(b.count()); + C.level = cache.Level; + C.size = static_cast(cache.Size); + C.type = "Unknown"; + switch (cache.Type) { +// Windows SDK version >= 10.0.26100.0 +// 0x0A000010 is the value of NTDDI_WIN11_GE +#if 
NTDDI_VERSION >= 0x0A000010 + case CacheUnknown: + break; +#endif + case CacheUnified: + C.type = "Unified"; + break; + case CacheInstruction: + C.type = "Instruction"; + break; + case CacheData: + C.type = "Data"; + break; + case CacheTrace: + C.type = "Trace"; + break; + } + res.push_back(C); + } + return res; +} +#elif BENCHMARK_OS_QNX +std::vector GetCacheSizesQNX() { + std::vector res; + struct cacheattr_entry* cache = SYSPAGE_ENTRY(cacheattr); + uint32_t const elsize = SYSPAGE_ELEMENT_SIZE(cacheattr); + int num = SYSPAGE_ENTRY_SIZE(cacheattr) / elsize; + for (int i = 0; i < num; ++i) { + CPUInfo::CacheInfo info; + switch (cache->flags) { + case CACHE_FLAG_INSTR: + info.type = "Instruction"; + info.level = 1; + break; + case CACHE_FLAG_DATA: + info.type = "Data"; + info.level = 1; + break; + case CACHE_FLAG_UNIFIED: + info.type = "Unified"; + info.level = 2; + break; + case CACHE_FLAG_SHARED: + info.type = "Shared"; + info.level = 3; + break; + default: + continue; + break; + } + info.size = cache->line_size * cache->num_lines; + info.num_sharing = 0; + res.push_back(std::move(info)); + cache = SYSPAGE_ARRAY_ADJ_OFFSET(cacheattr, cache, elsize); + } + return res; +} +#endif + +std::vector GetCacheSizes() { +#ifdef BENCHMARK_OS_MACOSX + return GetCacheSizesMacOSX(); +#elif defined(BENCHMARK_OS_WINDOWS) + return GetCacheSizesWindows(); +#elif defined(BENCHMARK_OS_QNX) + return GetCacheSizesQNX(); +#elif defined(BENCHMARK_OS_QURT) + return std::vector(); +#else + return GetCacheSizesFromKVFS(); +#endif +} + +std::string GetSystemName() { +#if defined(BENCHMARK_OS_WINDOWS) + std::string str; + static constexpr int COUNT = MAX_COMPUTERNAME_LENGTH + 1; + TCHAR hostname[COUNT] = {'\0'}; + DWORD DWCOUNT = COUNT; + if (!GetComputerName(hostname, &DWCOUNT)) return std::string(""); +#ifndef UNICODE + str = std::string(hostname, DWCOUNT); +#else + // `WideCharToMultiByte` returns `0` when conversion fails. 
+ int len = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, hostname, + DWCOUNT, NULL, 0, NULL, NULL); + str.resize(len); + WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, hostname, DWCOUNT, &str[0], + str.size(), NULL, NULL); +#endif + return str; +#elif defined(BENCHMARK_OS_QURT) + std::string str = "Hexagon DSP"; + qurt_arch_version_t arch_version_struct; + if (qurt_sysenv_get_arch_version(&arch_version_struct) == QURT_EOK) { + str += " v"; + str += std::to_string(arch_version_struct.arch_version); + } + return str; +#else +#ifndef HOST_NAME_MAX +#ifdef BENCHMARK_HAS_SYSCTL // BSD/Mac doesn't have HOST_NAME_MAX defined +#define HOST_NAME_MAX 64 +#elif defined(BENCHMARK_OS_NACL) +#define HOST_NAME_MAX 64 +#elif defined(BENCHMARK_OS_QNX) +#define HOST_NAME_MAX 154 +#elif defined(BENCHMARK_OS_RTEMS) +#define HOST_NAME_MAX 256 +#elif defined(BENCHMARK_OS_SOLARIS) +#define HOST_NAME_MAX MAXHOSTNAMELEN +#elif defined(BENCHMARK_OS_ZOS) +#define HOST_NAME_MAX _POSIX_HOST_NAME_MAX +#else +#pragma message("HOST_NAME_MAX not defined. using 64") +#define HOST_NAME_MAX 64 +#endif +#endif // def HOST_NAME_MAX + char hostname[HOST_NAME_MAX]; + int retVal = gethostname(hostname, HOST_NAME_MAX); + if (retVal != 0) return std::string(""); + return std::string(hostname); +#endif // Catch-all POSIX block. +} + +int GetNumCPUsImpl() { +#ifdef BENCHMARK_HAS_SYSCTL + int num_cpu = -1; + if (GetSysctl("hw.ncpu", &num_cpu)) return num_cpu; + PrintErrorAndDie("Err: ", strerror(errno)); +#elif defined(BENCHMARK_OS_WINDOWS) + SYSTEM_INFO sysinfo; + // Use memset as opposed to = {} to avoid GCC missing initializer false + // positives. + std::memset(&sysinfo, 0, sizeof(SYSTEM_INFO)); + GetSystemInfo(&sysinfo); + // number of logical processors in the current group + return static_cast(sysinfo.dwNumberOfProcessors); +#elif defined(BENCHMARK_OS_SOLARIS) + // Returns -1 in case of a failure. 
+ long num_cpu = sysconf(_SC_NPROCESSORS_ONLN); + if (num_cpu < 0) { + PrintErrorAndDie("sysconf(_SC_NPROCESSORS_ONLN) failed with error: ", + strerror(errno)); + } + return (int)num_cpu; +#elif defined(BENCHMARK_OS_QNX) + return static_cast(_syspage_ptr->num_cpu); +#elif defined(BENCHMARK_OS_QURT) + qurt_sysenv_max_hthreads_t hardware_threads; + if (qurt_sysenv_get_max_hw_threads(&hardware_threads) != QURT_EOK) { + hardware_threads.max_hthreads = 1; + } + return hardware_threads.max_hthreads; +#else + int num_cpus = 0; + int max_id = -1; + std::ifstream f("/proc/cpuinfo"); + if (!f.is_open()) { + std::cerr << "Failed to open /proc/cpuinfo\n"; + return -1; + } +#if defined(__alpha__) + const std::string Key = "cpus detected"; +#else + const std::string Key = "processor"; +#endif + std::string ln; + while (std::getline(f, ln)) { + if (ln.empty()) continue; + std::size_t split_idx = ln.find(':'); + std::string value; +#if defined(__s390__) + // s390 has another format in /proc/cpuinfo + // it needs to be parsed differently + if (split_idx != std::string::npos) + value = ln.substr(Key.size() + 1, split_idx - Key.size() - 1); +#else + if (split_idx != std::string::npos) value = ln.substr(split_idx + 1); +#endif + if (ln.size() >= Key.size() && ln.compare(0, Key.size(), Key) == 0) { + num_cpus++; + if (!value.empty()) { + const int cur_id = benchmark::stoi(value); + max_id = std::max(cur_id, max_id); + } + } + } + if (f.bad()) { + PrintErrorAndDie("Failure reading /proc/cpuinfo"); + } + if (!f.eof()) { + PrintErrorAndDie("Failed to read to end of /proc/cpuinfo"); + } + f.close(); + + if ((max_id + 1) != num_cpus) { + fprintf(stderr, + "CPU ID assignments in /proc/cpuinfo seem messed up." + " This is usually caused by a bad BIOS.\n"); + } + return num_cpus; +#endif + BENCHMARK_UNREACHABLE(); +} + +int GetNumCPUs() { + const int num_cpus = GetNumCPUsImpl(); + if (num_cpus < 1) { + std::cerr << "Unable to extract number of CPUs. 
If your platform uses " + "/proc/cpuinfo, custom support may need to be added.\n"; + } + return num_cpus; +} + +class ThreadAffinityGuard final { + public: + ThreadAffinityGuard() : reset_affinity(SetAffinity()) { + if (!reset_affinity) + std::cerr << "***WARNING*** Failed to set thread affinity. Estimated CPU " + "frequency may be incorrect." + << std::endl; + } + + ~ThreadAffinityGuard() { + if (!reset_affinity) return; + +#if defined(BENCHMARK_HAS_PTHREAD_AFFINITY) + int ret = pthread_setaffinity_np(self, sizeof(previous_affinity), + &previous_affinity); + if (ret == 0) return; +#elif defined(BENCHMARK_OS_WINDOWS_WIN32) + DWORD_PTR ret = SetThreadAffinityMask(self, previous_affinity); + if (ret != 0) return; +#endif // def BENCHMARK_HAS_PTHREAD_AFFINITY + PrintErrorAndDie("Failed to reset thread affinity"); + } + + ThreadAffinityGuard(ThreadAffinityGuard&&) = delete; + ThreadAffinityGuard(const ThreadAffinityGuard&) = delete; + ThreadAffinityGuard& operator=(ThreadAffinityGuard&&) = delete; + ThreadAffinityGuard& operator=(const ThreadAffinityGuard&) = delete; + + private: + bool SetAffinity() { +#if defined(BENCHMARK_HAS_PTHREAD_AFFINITY) + int ret; + self = pthread_self(); + ret = pthread_getaffinity_np(self, sizeof(previous_affinity), + &previous_affinity); + if (ret != 0) return false; + + cpu_set_t affinity; + memcpy(&affinity, &previous_affinity, sizeof(affinity)); + + bool is_first_cpu = true; + + for (int i = 0; i < CPU_SETSIZE; ++i) + if (CPU_ISSET(i, &affinity)) { + if (is_first_cpu) + is_first_cpu = false; + else + CPU_CLR(i, &affinity); + } + + if (is_first_cpu) return false; + + ret = pthread_setaffinity_np(self, sizeof(affinity), &affinity); + return ret == 0; +#elif defined(BENCHMARK_OS_WINDOWS_WIN32) + self = GetCurrentThread(); + DWORD_PTR mask = static_cast(1) << GetCurrentProcessorNumber(); + previous_affinity = SetThreadAffinityMask(self, mask); + return previous_affinity != 0; +#else + return false; +#endif // def 
BENCHMARK_HAS_PTHREAD_AFFINITY + } + +#if defined(BENCHMARK_HAS_PTHREAD_AFFINITY) + pthread_t self; + cpu_set_t previous_affinity; +#elif defined(BENCHMARK_OS_WINDOWS_WIN32) + HANDLE self; + DWORD_PTR previous_affinity; +#endif // def BENCHMARK_HAS_PTHREAD_AFFINITY + bool reset_affinity; +}; + +double GetCPUCyclesPerSecond(CPUInfo::Scaling scaling) { + // Currently, scaling is only used on linux path here, + // suppress diagnostics about it being unused on other paths. + (void)scaling; + +#if defined BENCHMARK_OS_LINUX || defined BENCHMARK_OS_CYGWIN + long freq; + + // If the kernel is exporting the tsc frequency use that. There are issues + // where cpuinfo_max_freq cannot be relied on because the BIOS may be + // exporintg an invalid p-state (on x86) or p-states may be used to put the + // processor in a new mode (turbo mode). Essentially, those frequencies + // cannot always be relied upon. The same reasons apply to /proc/cpuinfo as + // well. + if (ReadFromFile("/sys/devices/system/cpu/cpu0/tsc_freq_khz", &freq) + // If CPU scaling is disabled, use the *current* frequency. + // Note that we specifically don't want to read cpuinfo_cur_freq, + // because it is only readable by root. + || (scaling == CPUInfo::Scaling::DISABLED && + ReadFromFile("/sys/devices/system/cpu/cpu0/cpufreq/scaling_cur_freq", + &freq)) + // Otherwise, if CPU scaling may be in effect, we want to use + // the *maximum* frequency, not whatever CPU speed some random processor + // happens to be using now. + || ReadFromFile("/sys/devices/system/cpu/cpu0/cpufreq/cpuinfo_max_freq", + &freq)) { + // The value is in kHz (as the file name suggests). For example, on a + // 2GHz warpstation, the file contains the value "2000000". 
+ return static_cast(freq) * 1000.0; + } + + const double error_value = -1; + double bogo_clock = error_value; + + std::ifstream f("/proc/cpuinfo"); + if (!f.is_open()) { + std::cerr << "failed to open /proc/cpuinfo\n"; + return error_value; + } + + auto StartsWithKey = [](std::string const& Value, std::string const& Key) { + if (Key.size() > Value.size()) return false; + auto Cmp = [&](char X, char Y) { + return std::tolower(X) == std::tolower(Y); + }; + return std::equal(Key.begin(), Key.end(), Value.begin(), Cmp); + }; + + std::string ln; + while (std::getline(f, ln)) { + if (ln.empty()) continue; + std::size_t split_idx = ln.find(':'); + std::string value; + if (split_idx != std::string::npos) value = ln.substr(split_idx + 1); + // When parsing the "cpu MHz" and "bogomips" (fallback) entries, we only + // accept positive values. Some environments (virtual machines) report zero, + // which would cause infinite looping in WallTime_Init. + if (StartsWithKey(ln, "cpu MHz")) { + if (!value.empty()) { + double cycles_per_second = benchmark::stod(value) * 1000000.0; + if (cycles_per_second > 0) return cycles_per_second; + } + } else if (StartsWithKey(ln, "bogomips")) { + if (!value.empty()) { + bogo_clock = benchmark::stod(value) * 1000000.0; + if (bogo_clock < 0.0) bogo_clock = error_value; + } + } + } + if (f.bad()) { + std::cerr << "Failure reading /proc/cpuinfo\n"; + return error_value; + } + if (!f.eof()) { + std::cerr << "Failed to read to end of /proc/cpuinfo\n"; + return error_value; + } + f.close(); + // If we found the bogomips clock, but nothing better, we'll use it (but + // we're not happy about it); otherwise, fallback to the rough estimation + // below. 
+ if (bogo_clock >= 0.0) return bogo_clock; + +#elif defined BENCHMARK_HAS_SYSCTL + constexpr auto* freqStr = +#if defined(BENCHMARK_OS_FREEBSD) || defined(BENCHMARK_OS_NETBSD) + "machdep.tsc_freq"; +#elif defined BENCHMARK_OS_OPENBSD + "hw.cpuspeed"; +#elif defined BENCHMARK_OS_DRAGONFLY + "hw.tsc_frequency"; +#else + "hw.cpufrequency"; +#endif + unsigned long long hz = 0; +#if defined BENCHMARK_OS_OPENBSD + if (GetSysctl(freqStr, &hz)) return static_cast(hz * 1000000); +#else + if (GetSysctl(freqStr, &hz)) return static_cast(hz); +#endif + fprintf(stderr, "Unable to determine clock rate from sysctl: %s: %s\n", + freqStr, strerror(errno)); + fprintf(stderr, + "This does not affect benchmark measurements, only the " + "metadata output.\n"); + +#elif defined BENCHMARK_OS_WINDOWS_WIN32 + // In NT, read MHz from the registry. If we fail to do so or we're in win9x + // then make a crude estimate. + DWORD data, data_size = sizeof(data); + if (IsWindowsXPOrGreater() && + SUCCEEDED( + SHGetValueA(HKEY_LOCAL_MACHINE, + "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", + "~MHz", nullptr, &data, &data_size))) + return static_cast(static_cast(data) * + static_cast(1000 * 1000)); // was mhz +#elif defined(BENCHMARK_OS_SOLARIS) + kstat_ctl_t* kc = kstat_open(); + if (!kc) { + std::cerr << "failed to open /dev/kstat\n"; + return -1; + } + kstat_t* ksp = kstat_lookup(kc, const_cast("cpu_info"), -1, + const_cast("cpu_info0")); + if (!ksp) { + std::cerr << "failed to lookup in /dev/kstat\n"; + return -1; + } + if (kstat_read(kc, ksp, NULL) < 0) { + std::cerr << "failed to read from /dev/kstat\n"; + return -1; + } + kstat_named_t* knp = (kstat_named_t*)kstat_data_lookup( + ksp, const_cast("current_clock_Hz")); + if (!knp) { + std::cerr << "failed to lookup data in /dev/kstat\n"; + return -1; + } + if (knp->data_type != KSTAT_DATA_UINT64) { + std::cerr << "current_clock_Hz is of unexpected data type: " + << knp->data_type << "\n"; + return -1; + } + double clock_hz = 
knp->value.ui64; + kstat_close(kc); + return clock_hz; +#elif defined(BENCHMARK_OS_QNX) + return static_cast( + static_cast(SYSPAGE_ENTRY(cpuinfo)->speed) * + static_cast(1000 * 1000)); +#elif defined(BENCHMARK_OS_QURT) + // QuRT doesn't provide any API to query Hexagon frequency. + return 1000000000; +#endif + // If we've fallen through, attempt to roughly estimate the CPU clock rate. + + // Make sure to use the same cycle counter when starting and stopping the + // cycle timer. We just pin the current thread to a cpu in the previous + // affinity set. + ThreadAffinityGuard affinity_guard; + + static constexpr double estimate_time_s = 1.0; + const double start_time = ChronoClockNow(); + const auto start_ticks = cycleclock::Now(); + + // Impose load instead of calling sleep() to make sure the cycle counter + // works. + using PRNG = std::minstd_rand; + using Result = PRNG::result_type; + PRNG rng(static_cast(start_ticks)); + + Result state = 0; + + do { + static constexpr size_t batch_size = 10000; + rng.discard(batch_size); + state += rng(); + + } while (ChronoClockNow() - start_time < estimate_time_s); + + DoNotOptimize(state); + + const auto end_ticks = cycleclock::Now(); + const double end_time = ChronoClockNow(); + + return static_cast(end_ticks - start_ticks) / (end_time - start_time); + // Reset the affinity of current thread when the lifetime of affinity_guard + // ends. 
+} + +std::vector GetLoadAvg() { +#if (defined BENCHMARK_OS_FREEBSD || defined(BENCHMARK_OS_LINUX) || \ + defined BENCHMARK_OS_MACOSX || defined BENCHMARK_OS_NETBSD || \ + defined BENCHMARK_OS_OPENBSD || defined BENCHMARK_OS_DRAGONFLY) && \ + !(defined(__ANDROID__) && __ANDROID_API__ < 29) + static constexpr int kMaxSamples = 3; + std::vector res(kMaxSamples, 0.0); + const size_t nelem = static_cast(getloadavg(res.data(), kMaxSamples)); + if (nelem < 1) { + res.clear(); + } else { + res.resize(nelem); + } + return res; +#else + return {}; +#endif +} + +} // end namespace + +const CPUInfo& CPUInfo::Get() { + static const CPUInfo* info = new CPUInfo(); + return *info; +} + +CPUInfo::CPUInfo() + : num_cpus(GetNumCPUs()), + scaling(CpuScaling(num_cpus)), + cycles_per_second(GetCPUCyclesPerSecond(scaling)), + caches(GetCacheSizes()), + load_avg(GetLoadAvg()) {} + +const SystemInfo& SystemInfo::Get() { + static const SystemInfo* info = new SystemInfo(); + return *info; +} + +SystemInfo::SystemInfo() : name(GetSystemName()) {} +} // end namespace benchmark diff --git a/third_party/benchmark/src/thread_manager.h b/third_party/benchmark/src/thread_manager.h new file mode 100644 index 0000000..819b3c4 --- /dev/null +++ b/third_party/benchmark/src/thread_manager.h @@ -0,0 +1,63 @@ +#ifndef BENCHMARK_THREAD_MANAGER_H +#define BENCHMARK_THREAD_MANAGER_H + +#include + +#include "benchmark/benchmark.h" +#include "mutex.h" + +namespace benchmark { +namespace internal { + +class ThreadManager { + public: + explicit ThreadManager(int num_threads) + : alive_threads_(num_threads), start_stop_barrier_(num_threads) {} + + Mutex& GetBenchmarkMutex() const RETURN_CAPABILITY(benchmark_mutex_) { + return benchmark_mutex_; + } + + bool StartStopBarrier() EXCLUDES(end_cond_mutex_) { + return start_stop_barrier_.wait(); + } + + void NotifyThreadComplete() EXCLUDES(end_cond_mutex_) { + start_stop_barrier_.removeThread(); + if (--alive_threads_ == 0) { + MutexLock lock(end_cond_mutex_); + 
end_condition_.notify_all(); + } + } + + void WaitForAllThreads() EXCLUDES(end_cond_mutex_) { + MutexLock lock(end_cond_mutex_); + end_condition_.wait(lock.native_handle(), + [this]() { return alive_threads_ == 0; }); + } + + struct Result { + IterationCount iterations = 0; + double real_time_used = 0; + double cpu_time_used = 0; + double manual_time_used = 0; + int64_t complexity_n = 0; + std::string report_label_; + std::string skip_message_; + internal::Skipped skipped_ = internal::NotSkipped; + UserCounters counters; + }; + GUARDED_BY(GetBenchmarkMutex()) Result results; + + private: + mutable Mutex benchmark_mutex_; + std::atomic alive_threads_; + Barrier start_stop_barrier_; + Mutex end_cond_mutex_; + Condition end_condition_; +}; + +} // namespace internal +} // namespace benchmark + +#endif // BENCHMARK_THREAD_MANAGER_H diff --git a/third_party/benchmark/src/thread_timer.h b/third_party/benchmark/src/thread_timer.h new file mode 100644 index 0000000..eb23f59 --- /dev/null +++ b/third_party/benchmark/src/thread_timer.h @@ -0,0 +1,86 @@ +#ifndef BENCHMARK_THREAD_TIMER_H +#define BENCHMARK_THREAD_TIMER_H + +#include "check.h" +#include "timers.h" + +namespace benchmark { +namespace internal { + +class ThreadTimer { + explicit ThreadTimer(bool measure_process_cpu_time_) + : measure_process_cpu_time(measure_process_cpu_time_) {} + + public: + static ThreadTimer Create() { + return ThreadTimer(/*measure_process_cpu_time_=*/false); + } + static ThreadTimer CreateProcessCpuTime() { + return ThreadTimer(/*measure_process_cpu_time_=*/true); + } + + // Called by each thread + void StartTimer() { + running_ = true; + start_real_time_ = ChronoClockNow(); + start_cpu_time_ = ReadCpuTimerOfChoice(); + } + + // Called by each thread + void StopTimer() { + BM_CHECK(running_); + running_ = false; + real_time_used_ += ChronoClockNow() - start_real_time_; + // Floating point error can result in the subtraction producing a negative + // time. Guard against that. 
+ cpu_time_used_ += + std::max(ReadCpuTimerOfChoice() - start_cpu_time_, 0); + } + + // Called by each thread + void SetIterationTime(double seconds) { manual_time_used_ += seconds; } + + bool running() const { return running_; } + + // REQUIRES: timer is not running + double real_time_used() const { + BM_CHECK(!running_); + return real_time_used_; + } + + // REQUIRES: timer is not running + double cpu_time_used() const { + BM_CHECK(!running_); + return cpu_time_used_; + } + + // REQUIRES: timer is not running + double manual_time_used() const { + BM_CHECK(!running_); + return manual_time_used_; + } + + private: + double ReadCpuTimerOfChoice() const { + if (measure_process_cpu_time) return ProcessCPUUsage(); + return ThreadCPUUsage(); + } + + // should the thread, or the process, time be measured? + const bool measure_process_cpu_time; + + bool running_ = false; // Is the timer running + double start_real_time_ = 0; // If running_ + double start_cpu_time_ = 0; // If running_ + + // Accumulated time so far (does not contain current slice if running_) + double real_time_used_ = 0; + double cpu_time_used_ = 0; + // Manually set iteration time. User sets this with SetIterationTime(seconds). + double manual_time_used_ = 0; +}; + +} // namespace internal +} // namespace benchmark + +#endif // BENCHMARK_THREAD_TIMER_H diff --git a/third_party/benchmark/src/timers.cc b/third_party/benchmark/src/timers.cc new file mode 100644 index 0000000..7ba540b --- /dev/null +++ b/third_party/benchmark/src/timers.cc @@ -0,0 +1,284 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "timers.h" + +#include "internal_macros.h" + +#ifdef BENCHMARK_OS_WINDOWS +#include +#undef StrCat // Don't let StrCat in string_util.h be renamed to lstrcatA +#include +#include +#else +#include +#if !defined(BENCHMARK_OS_FUCHSIA) && !defined(BENCHMARK_OS_QURT) +#include +#endif +#include +#include // this header must be included before 'sys/sysctl.h' to avoid compilation error on FreeBSD +#include +#if defined BENCHMARK_OS_FREEBSD || defined BENCHMARK_OS_DRAGONFLY || \ + defined BENCHMARK_OS_MACOSX +#include +#endif +#if defined(BENCHMARK_OS_MACOSX) +#include +#include +#include +#endif +#if defined(BENCHMARK_OS_QURT) +#include +#endif +#endif + +#ifdef BENCHMARK_OS_EMSCRIPTEN +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "check.h" +#include "log.h" +#include "string_util.h" + +namespace benchmark { + +// Suppress unused warnings on helper functions. 
+#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wunused-function" +#endif +#if defined(__NVCOMPILER) +#pragma diag_suppress declared_but_not_referenced +#endif + +namespace { +#if defined(BENCHMARK_OS_WINDOWS) +double MakeTime(FILETIME const& kernel_time, FILETIME const& user_time) { + ULARGE_INTEGER kernel; + ULARGE_INTEGER user; + kernel.HighPart = kernel_time.dwHighDateTime; + kernel.LowPart = kernel_time.dwLowDateTime; + user.HighPart = user_time.dwHighDateTime; + user.LowPart = user_time.dwLowDateTime; + return (static_cast(kernel.QuadPart) + + static_cast(user.QuadPart)) * + 1e-7; +} +#elif !defined(BENCHMARK_OS_FUCHSIA) && !defined(BENCHMARK_OS_QURT) +double MakeTime(struct rusage const& ru) { + return (static_cast(ru.ru_utime.tv_sec) + + static_cast(ru.ru_utime.tv_usec) * 1e-6 + + static_cast(ru.ru_stime.tv_sec) + + static_cast(ru.ru_stime.tv_usec) * 1e-6); +} +#endif +#if defined(BENCHMARK_OS_MACOSX) +double MakeTime(thread_basic_info_data_t const& info) { + return (static_cast(info.user_time.seconds) + + static_cast(info.user_time.microseconds) * 1e-6 + + static_cast(info.system_time.seconds) + + static_cast(info.system_time.microseconds) * 1e-6); +} +#endif +#if defined(CLOCK_PROCESS_CPUTIME_ID) || defined(CLOCK_THREAD_CPUTIME_ID) +double MakeTime(struct timespec const& ts) { + return static_cast(ts.tv_sec) + + (static_cast(ts.tv_nsec) * 1e-9); +} +#endif + +BENCHMARK_NORETURN static void DiagnoseAndExit(const char* msg) { + std::cerr << "ERROR: " << msg << std::endl; + std::exit(EXIT_FAILURE); +} + +} // end namespace + +double ProcessCPUUsage() { +#if defined(BENCHMARK_OS_WINDOWS) + HANDLE proc = GetCurrentProcess(); + FILETIME creation_time; + FILETIME exit_time; + FILETIME kernel_time; + FILETIME user_time; + if (GetProcessTimes(proc, &creation_time, &exit_time, &kernel_time, + &user_time)) + return MakeTime(kernel_time, user_time); + DiagnoseAndExit("GetProccessTimes() failed"); +#elif defined(BENCHMARK_OS_QURT) + // Note that 
qurt_timer_get_ticks() is no longer documented as of SDK 5.3.0, + // and doesn't appear to work on at least some devices (eg Samsung S22), + // so let's use the actually-documented and apparently-equivalent + // qurt_sysclock_get_hw_ticks() call instead. + return static_cast( + qurt_timer_timetick_to_us(qurt_sysclock_get_hw_ticks())) * + 1.0e-6; +#elif defined(BENCHMARK_OS_EMSCRIPTEN) + // clock_gettime(CLOCK_PROCESS_CPUTIME_ID, ...) returns 0 on Emscripten. + // Use Emscripten-specific API. Reported CPU time would be exactly the + // same as total time, but this is ok because there aren't long-latency + // synchronous system calls in Emscripten. + return emscripten_get_now() * 1e-3; +#elif defined(CLOCK_PROCESS_CPUTIME_ID) && !defined(BENCHMARK_OS_MACOSX) + // FIXME We want to use clock_gettime, but its not available in MacOS 10.11. + // See https://github.com/google/benchmark/pull/292 + struct timespec spec; + if (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &spec) == 0) + return MakeTime(spec); + DiagnoseAndExit("clock_gettime(CLOCK_PROCESS_CPUTIME_ID, ...) failed"); +#else + struct rusage ru; + if (getrusage(RUSAGE_SELF, &ru) == 0) return MakeTime(ru); + DiagnoseAndExit("getrusage(RUSAGE_SELF, ...) failed"); +#endif +} + +double ThreadCPUUsage() { +#if defined(BENCHMARK_OS_WINDOWS) + HANDLE this_thread = GetCurrentThread(); + FILETIME creation_time; + FILETIME exit_time; + FILETIME kernel_time; + FILETIME user_time; + GetThreadTimes(this_thread, &creation_time, &exit_time, &kernel_time, + &user_time); + return MakeTime(kernel_time, user_time); +#elif defined(BENCHMARK_OS_QURT) + // Note that qurt_timer_get_ticks() is no longer documented as of SDK 5.3.0, + // and doesn't appear to work on at least some devices (eg Samsung S22), + // so let's use the actually-documented and apparently-equivalent + // qurt_sysclock_get_hw_ticks() call instead. 
+ return static_cast( + qurt_timer_timetick_to_us(qurt_sysclock_get_hw_ticks())) * + 1.0e-6; +#elif defined(BENCHMARK_OS_MACOSX) + // FIXME We want to use clock_gettime, but its not available in MacOS 10.11. + // See https://github.com/google/benchmark/pull/292 + mach_msg_type_number_t count = THREAD_BASIC_INFO_COUNT; + thread_basic_info_data_t info; + mach_port_t thread = pthread_mach_thread_np(pthread_self()); + if (thread_info(thread, THREAD_BASIC_INFO, + reinterpret_cast(&info), + &count) == KERN_SUCCESS) { + return MakeTime(info); + } + DiagnoseAndExit("ThreadCPUUsage() failed when evaluating thread_info"); +#elif defined(BENCHMARK_OS_EMSCRIPTEN) + // Emscripten doesn't support traditional threads + return ProcessCPUUsage(); +#elif defined(BENCHMARK_OS_RTEMS) + // RTEMS doesn't support CLOCK_THREAD_CPUTIME_ID. See + // https://github.com/RTEMS/rtems/blob/master/cpukit/posix/src/clockgettime.c + return ProcessCPUUsage(); +#elif defined(BENCHMARK_OS_ZOS) + // z/OS doesn't support CLOCK_THREAD_CPUTIME_ID. + return ProcessCPUUsage(); +#elif defined(BENCHMARK_OS_SOLARIS) + struct rusage ru; + if (getrusage(RUSAGE_LWP, &ru) == 0) return MakeTime(ru); + DiagnoseAndExit("getrusage(RUSAGE_LWP, ...) failed"); +#elif defined(CLOCK_THREAD_CPUTIME_ID) + struct timespec ts; + if (clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts) == 0) return MakeTime(ts); + DiagnoseAndExit("clock_gettime(CLOCK_THREAD_CPUTIME_ID, ...) failed"); +#else +#error Per-thread timing is not available on your system. +#endif +} + +std::string LocalDateTimeString() { + // Write the local time in RFC3339 format yyyy-mm-ddTHH:MM:SS+/-HH:MM. 
+ typedef std::chrono::system_clock Clock; + std::time_t now = Clock::to_time_t(Clock::now()); + const std::size_t kTzOffsetLen = 6; + const std::size_t kTimestampLen = 19; + + std::size_t tz_len; + std::size_t timestamp_len; + long int offset_minutes; + char tz_offset_sign = '+'; + // tz_offset is set in one of three ways: + // * strftime with %z - This either returns empty or the ISO 8601 time. The + // maximum length an + // ISO 8601 string can be is 7 (e.g. -03:30, plus trailing zero). + // * snprintf with %c%02li:%02li - The maximum length is 41 (one for %c, up to + // 19 for %02li, + // one for :, up to 19 %02li, plus trailing zero). + // * A fixed string of "-00:00". The maximum length is 7 (-00:00, plus + // trailing zero). + // + // Thus, the maximum size this needs to be is 41. + char tz_offset[41]; + // Long enough buffer to avoid format-overflow warnings + char storage[128]; + +#if defined(BENCHMARK_OS_WINDOWS) + std::tm* timeinfo_p = ::localtime(&now); +#else + std::tm timeinfo; + std::tm* timeinfo_p = &timeinfo; + ::localtime_r(&now, &timeinfo); +#endif + + tz_len = std::strftime(tz_offset, sizeof(tz_offset), "%z", timeinfo_p); + + if (tz_len < kTzOffsetLen && tz_len > 1) { + // Timezone offset was written. strftime writes offset as +HHMM or -HHMM, + // RFC3339 specifies an offset as +HH:MM or -HH:MM. To convert, we parse + // the offset as an integer, then reprint it to a string. + + offset_minutes = ::strtol(tz_offset, NULL, 10); + if (offset_minutes < 0) { + offset_minutes *= -1; + tz_offset_sign = '-'; + } + + tz_len = static_cast( + ::snprintf(tz_offset, sizeof(tz_offset), "%c%02li:%02li", + tz_offset_sign, offset_minutes / 100, offset_minutes % 100)); + BM_CHECK(tz_len == kTzOffsetLen); + ((void)tz_len); // Prevent unused variable warning in optimized build. + } else { + // Unknown offset. RFC3339 specifies that unknown local offsets should be + // written as UTC time with -00:00 timezone. 
+#if defined(BENCHMARK_OS_WINDOWS) + // Potential race condition if another thread calls localtime or gmtime. + timeinfo_p = ::gmtime(&now); +#else + ::gmtime_r(&now, &timeinfo); +#endif + + strncpy(tz_offset, "-00:00", kTzOffsetLen + 1); + } + + timestamp_len = + std::strftime(storage, sizeof(storage), "%Y-%m-%dT%H:%M:%S", timeinfo_p); + BM_CHECK(timestamp_len == kTimestampLen); + // Prevent unused variable warning in optimized build. + ((void)kTimestampLen); + + std::strncat(storage, tz_offset, sizeof(storage) - timestamp_len - 1); + return std::string(storage); +} + +} // end namespace benchmark diff --git a/third_party/benchmark/src/timers.h b/third_party/benchmark/src/timers.h new file mode 100644 index 0000000..690086b --- /dev/null +++ b/third_party/benchmark/src/timers.h @@ -0,0 +1,75 @@ +#ifndef BENCHMARK_TIMERS_H +#define BENCHMARK_TIMERS_H + +#include +#include + +namespace benchmark { + +// Return the CPU usage of the current process +double ProcessCPUUsage(); + +// Return the CPU usage of the children of the current process +double ChildrenCPUUsage(); + +// Return the CPU usage of the current thread +double ThreadCPUUsage(); + +#if defined(BENCHMARK_OS_QURT) + +// std::chrono::now() can return 0 on some Hexagon devices; +// this reads the value of a 56-bit, 19.2MHz hardware counter +// and converts it to seconds. Unlike std::chrono, this doesn't +// return an absolute time, but since ChronoClockNow() is only used +// to compute elapsed time, this shouldn't matter. 
+struct QuRTClock { + typedef uint64_t rep; + typedef std::ratio<1, 19200000> period; + typedef std::chrono::duration duration; + typedef std::chrono::time_point time_point; + static const bool is_steady = false; + + static time_point now() { + unsigned long long count; + asm volatile(" %0 = c31:30 " : "=r"(count)); + return time_point(static_cast(count)); + } +}; + +#else + +#if defined(HAVE_STEADY_CLOCK) +template +struct ChooseSteadyClock { + typedef std::chrono::high_resolution_clock type; +}; + +template <> +struct ChooseSteadyClock { + typedef std::chrono::steady_clock type; +}; +#endif // HAVE_STEADY_CLOCK + +#endif + +struct ChooseClockType { +#if defined(BENCHMARK_OS_QURT) + typedef QuRTClock type; +#elif defined(HAVE_STEADY_CLOCK) + typedef ChooseSteadyClock<>::type type; +#else + typedef std::chrono::high_resolution_clock type; +#endif +}; + +inline double ChronoClockNow() { + typedef ChooseClockType::type ClockType; + using FpSeconds = std::chrono::duration; + return FpSeconds(ClockType::now().time_since_epoch()).count(); +} + +std::string LocalDateTimeString(); + +} // end namespace benchmark + +#endif // BENCHMARK_TIMERS_H diff --git a/third_party/benchmark/test/AssemblyTests.cmake b/third_party/benchmark/test/AssemblyTests.cmake new file mode 100644 index 0000000..c43c711 --- /dev/null +++ b/third_party/benchmark/test/AssemblyTests.cmake @@ -0,0 +1,67 @@ +set(CLANG_SUPPORTED_VERSION "5.0.0") +set(GCC_SUPPORTED_VERSION "5.5.0") + +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL ${CLANG_SUPPORTED_VERSION}) + message (WARNING + "Unsupported Clang version " ${CMAKE_CXX_COMPILER_VERSION} + ". Expected is " ${CLANG_SUPPORTED_VERSION} + ". Assembly tests may be broken.") + endif() +elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU") + if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL ${GCC_SUPPORTED_VERSION}) + message (WARNING + "Unsupported GCC version " ${CMAKE_CXX_COMPILER_VERSION} + ". 
Expected is " ${GCC_SUPPORTED_VERSION} + ". Assembly tests may be broken.") + endif() +else() + message (WARNING "Unsupported compiler. Assembly tests may be broken.") +endif() + +include(split_list) + +set(ASM_TEST_FLAGS "") +check_cxx_compiler_flag(-O3 BENCHMARK_HAS_O3_FLAG) +if (BENCHMARK_HAS_O3_FLAG) + list(APPEND ASM_TEST_FLAGS -O3) +endif() + +check_cxx_compiler_flag(-g0 BENCHMARK_HAS_G0_FLAG) +if (BENCHMARK_HAS_G0_FLAG) + list(APPEND ASM_TEST_FLAGS -g0) +endif() + +check_cxx_compiler_flag(-fno-stack-protector BENCHMARK_HAS_FNO_STACK_PROTECTOR_FLAG) +if (BENCHMARK_HAS_FNO_STACK_PROTECTOR_FLAG) + list(APPEND ASM_TEST_FLAGS -fno-stack-protector) +endif() + +split_list(ASM_TEST_FLAGS) +string(TOUPPER "${CMAKE_CXX_COMPILER_ID}" ASM_TEST_COMPILER) + +macro(add_filecheck_test name) + cmake_parse_arguments(ARG "" "" "CHECK_PREFIXES" ${ARGV}) + add_library(${name} OBJECT ${name}.cc) + target_link_libraries(${name} PRIVATE benchmark::benchmark) + set_target_properties(${name} PROPERTIES COMPILE_FLAGS "-S ${ASM_TEST_FLAGS}") + set(ASM_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${name}.s") + add_custom_target(copy_${name} ALL + COMMAND ${PROJECT_SOURCE_DIR}/tools/strip_asm.py + $ + ${ASM_OUTPUT_FILE} + BYPRODUCTS ${ASM_OUTPUT_FILE}) + add_dependencies(copy_${name} ${name}) + if (NOT ARG_CHECK_PREFIXES) + set(ARG_CHECK_PREFIXES "CHECK") + endif() + foreach(prefix ${ARG_CHECK_PREFIXES}) + add_test(NAME run_${name}_${prefix} + COMMAND + ${LLVM_FILECHECK_EXE} ${name}.cc + --input-file=${ASM_OUTPUT_FILE} + --check-prefixes=CHECK,CHECK-${ASM_TEST_COMPILER} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + endforeach() +endmacro() + diff --git a/third_party/benchmark/test/BUILD b/third_party/benchmark/test/BUILD new file mode 100644 index 0000000..c1ca86b --- /dev/null +++ b/third_party/benchmark/test/BUILD @@ -0,0 +1,132 @@ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +platform( + name = "windows", + constraint_values = [ + "@platforms//os:windows", + ], +) + 
+TEST_COPTS = [ + "-pedantic", + "-pedantic-errors", + "-std=c++17", + "-Wall", + "-Wconversion", + "-Wextra", + "-Wshadow", + # "-Wshorten-64-to-32", + "-Wfloat-equal", + "-fstrict-aliasing", + ## assert() are used a lot in tests upstream, which may be optimised out leading to + ## unused-variable warning. + "-Wno-unused-variable", + "-Werror=old-style-cast", +] + +# Some of the issues with DoNotOptimize only occur when optimization is enabled +PER_SRC_COPTS = { + "donotoptimize_test.cc": ["-O3"], +} + +TEST_ARGS = ["--benchmark_min_time=0.01s"] + +PER_SRC_TEST_ARGS = { + "user_counters_tabular_test.cc": ["--benchmark_counters_tabular=true"], + "repetitions_test.cc": [" --benchmark_repetitions=3"], + "spec_arg_test.cc": ["--benchmark_filter=BM_NotChosen"], + "spec_arg_verbosity_test.cc": ["--v=42"], + "complexity_test.cc": ["--benchmark_min_time=1000000x"], +} + +cc_library( + name = "output_test_helper", + testonly = 1, + srcs = ["output_test_helper.cc"], + hdrs = ["output_test.h"], + copts = select({ + "//:windows": [], + "//conditions:default": TEST_COPTS, + }), + deps = [ + "//:benchmark", + "//:benchmark_internal_headers", + ], +) + +# Tests that use gtest. These rely on `gtest_main`. +[ + cc_test( + name = test_src[:-len(".cc")], + size = "small", + srcs = [test_src], + copts = select({ + "//:windows": [], + "//conditions:default": TEST_COPTS, + }) + PER_SRC_COPTS.get(test_src, []), + deps = [ + "//:benchmark", + "//:benchmark_internal_headers", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], + ) + for test_src in glob(["*_gtest.cc"]) +] + +# Tests that do not use gtest. These have their own `main` defined. 
+[ + cc_test( + name = test_src[:-len(".cc")], + size = "small", + srcs = [test_src], + args = TEST_ARGS + PER_SRC_TEST_ARGS.get(test_src, []), + copts = select({ + "//:windows": [], + "//conditions:default": TEST_COPTS, + }) + PER_SRC_COPTS.get(test_src, []), + deps = [ + ":output_test_helper", + "//:benchmark", + "//:benchmark_internal_headers", + ], + # FIXME: Add support for assembly tests to bazel. + # See Issue #556 + # https://github.com/google/benchmark/issues/556 + ) + for test_src in glob( + ["*_test.cc"], + exclude = [ + "*_assembly_test.cc", + "cxx03_test.cc", + "link_main_test.cc", + ], + ) +] + +cc_test( + name = "cxx03_test", + size = "small", + srcs = ["cxx03_test.cc"], + copts = TEST_COPTS + ["-std=c++03"], + target_compatible_with = select({ + "//:windows": ["@platforms//:incompatible"], + "//conditions:default": [], + }), + deps = [ + ":output_test_helper", + "//:benchmark", + "//:benchmark_internal_headers", + ], +) + +cc_test( + name = "link_main_test", + size = "small", + srcs = ["link_main_test.cc"], + copts = select({ + "//:windows": [], + "//conditions:default": TEST_COPTS, + }), + deps = ["//:benchmark_main"], +) diff --git a/third_party/benchmark/test/CMakeLists.txt b/third_party/benchmark/test/CMakeLists.txt new file mode 100644 index 0000000..321e24d --- /dev/null +++ b/third_party/benchmark/test/CMakeLists.txt @@ -0,0 +1,322 @@ +# Enable the tests + +set(THREADS_PREFER_PTHREAD_FLAG ON) + +find_package(Threads REQUIRED) +include(CheckCXXCompilerFlag) + +add_cxx_compiler_flag(-Wno-unused-variable) + +# NOTE: Some tests use `` to perform the test. Therefore we must +# strip -DNDEBUG from the default CMake flags in DEBUG mode. +string(TOUPPER "${CMAKE_BUILD_TYPE}" uppercase_CMAKE_BUILD_TYPE) +if( NOT uppercase_CMAKE_BUILD_TYPE STREQUAL "DEBUG" ) + add_definitions( -UNDEBUG ) + add_definitions(-DTEST_BENCHMARK_LIBRARY_HAS_NO_ASSERTIONS) + # Also remove /D NDEBUG to avoid MSVC warnings about conflicting defines. 
+ foreach (flags_var_to_scrub + CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_CXX_FLAGS_MINSIZEREL + CMAKE_C_FLAGS_RELEASE + CMAKE_C_FLAGS_RELWITHDEBINFO + CMAKE_C_FLAGS_MINSIZEREL) + string (REGEX REPLACE "(^| )[/-]D *NDEBUG($| )" " " + "${flags_var_to_scrub}" "${${flags_var_to_scrub}}") + endforeach() +endif() + +if (NOT BUILD_SHARED_LIBS) + add_definitions(-DBENCHMARK_STATIC_DEFINE) +endif() + +check_cxx_compiler_flag(-O3 BENCHMARK_HAS_O3_FLAG) +set(BENCHMARK_O3_FLAG "") +if (BENCHMARK_HAS_O3_FLAG) + set(BENCHMARK_O3_FLAG "-O3") +endif() + +# NOTE: These flags must be added after find_package(Threads REQUIRED) otherwise +# they will break the configuration check. +if (DEFINED BENCHMARK_CXX_LINKER_FLAGS) + list(APPEND CMAKE_EXE_LINKER_FLAGS ${BENCHMARK_CXX_LINKER_FLAGS}) +endif() + +add_library(output_test_helper STATIC output_test_helper.cc output_test.h) +target_link_libraries(output_test_helper PRIVATE benchmark::benchmark) + +macro(compile_benchmark_test name) + add_executable(${name} "${name}.cc") + target_link_libraries(${name} benchmark::benchmark ${CMAKE_THREAD_LIBS_INIT}) + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "NVHPC") + target_compile_options( ${name} PRIVATE --diag_suppress partial_override ) + endif() +endmacro(compile_benchmark_test) + +macro(compile_benchmark_test_with_main name) + add_executable(${name} "${name}.cc") + target_link_libraries(${name} benchmark::benchmark_main) +endmacro(compile_benchmark_test_with_main) + +macro(compile_output_test name) + add_executable(${name} "${name}.cc" output_test.h) + target_link_libraries(${name} output_test_helper benchmark::benchmark_main + ${BENCHMARK_CXX_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) +endmacro(compile_output_test) + +macro(benchmark_add_test) + add_test(${ARGV}) + if(WIN32 AND BUILD_SHARED_LIBS) + cmake_parse_arguments(TEST "" "NAME" "" ${ARGN}) + set_tests_properties(${TEST_NAME} PROPERTIES ENVIRONMENT_MODIFICATION "PATH=path_list_prepend:$") + endif() 
+endmacro(benchmark_add_test) + +# Demonstration executable +compile_benchmark_test(benchmark_test) +benchmark_add_test(NAME benchmark COMMAND benchmark_test --benchmark_min_time=0.01s) + +compile_benchmark_test(spec_arg_test) +benchmark_add_test(NAME spec_arg COMMAND spec_arg_test --benchmark_filter=BM_NotChosen) + +compile_benchmark_test(spec_arg_verbosity_test) +benchmark_add_test(NAME spec_arg_verbosity COMMAND spec_arg_verbosity_test --v=42) + +compile_benchmark_test(benchmark_setup_teardown_test) +benchmark_add_test(NAME benchmark_setup_teardown COMMAND benchmark_setup_teardown_test) + +compile_benchmark_test(filter_test) +macro(add_filter_test name filter expect) + benchmark_add_test(NAME ${name} COMMAND filter_test --benchmark_min_time=0.01s --benchmark_filter=${filter} ${expect}) + benchmark_add_test(NAME ${name}_list_only COMMAND filter_test --benchmark_list_tests --benchmark_filter=${filter} ${expect}) +endmacro(add_filter_test) + +compile_benchmark_test(benchmark_min_time_flag_time_test) +benchmark_add_test(NAME min_time_flag_time COMMAND benchmark_min_time_flag_time_test) + +compile_benchmark_test(benchmark_min_time_flag_iters_test) +benchmark_add_test(NAME min_time_flag_iters COMMAND benchmark_min_time_flag_iters_test) + +add_filter_test(filter_simple "Foo" 3) +add_filter_test(filter_simple_negative "-Foo" 2) +add_filter_test(filter_suffix "BM_.*" 4) +add_filter_test(filter_suffix_negative "-BM_.*" 1) +add_filter_test(filter_regex_all ".*" 5) +add_filter_test(filter_regex_all_negative "-.*" 0) +add_filter_test(filter_regex_blank "" 5) +add_filter_test(filter_regex_blank_negative "-" 0) +add_filter_test(filter_regex_none "monkey" 0) +add_filter_test(filter_regex_none_negative "-monkey" 5) +add_filter_test(filter_regex_wildcard ".*Foo.*" 3) +add_filter_test(filter_regex_wildcard_negative "-.*Foo.*" 2) +add_filter_test(filter_regex_begin "^BM_.*" 4) +add_filter_test(filter_regex_begin_negative "-^BM_.*" 1) +add_filter_test(filter_regex_begin2 "^N" 1) 
+add_filter_test(filter_regex_begin2_negative "-^N" 4) +add_filter_test(filter_regex_end ".*Ba$" 1) +add_filter_test(filter_regex_end_negative "-.*Ba$" 4) + +compile_benchmark_test(options_test) +benchmark_add_test(NAME options_benchmarks COMMAND options_test --benchmark_min_time=0.01s) + +compile_benchmark_test(basic_test) +benchmark_add_test(NAME basic_benchmark COMMAND basic_test --benchmark_min_time=0.01s) + +compile_output_test(repetitions_test) +benchmark_add_test(NAME repetitions_benchmark COMMAND repetitions_test --benchmark_min_time=0.01s --benchmark_repetitions=3) + +compile_benchmark_test(diagnostics_test) +benchmark_add_test(NAME diagnostics_test COMMAND diagnostics_test --benchmark_min_time=0.01s) + +compile_benchmark_test(skip_with_error_test) +benchmark_add_test(NAME skip_with_error_test COMMAND skip_with_error_test --benchmark_min_time=0.01s) + +compile_benchmark_test(donotoptimize_test) +# Enable errors for deprecated deprecations (DoNotOptimize(Tp const& value)). +check_cxx_compiler_flag(-Werror=deprecated-declarations BENCHMARK_HAS_DEPRECATED_DECLARATIONS_FLAG) +if (BENCHMARK_HAS_DEPRECATED_DECLARATIONS_FLAG) + target_compile_options (donotoptimize_test PRIVATE "-Werror=deprecated-declarations") +endif() +# Some of the issues with DoNotOptimize only occur when optimization is enabled +check_cxx_compiler_flag(-O3 BENCHMARK_HAS_O3_FLAG) +if (BENCHMARK_HAS_O3_FLAG) + set_target_properties(donotoptimize_test PROPERTIES COMPILE_FLAGS "-O3") +endif() +benchmark_add_test(NAME donotoptimize_test COMMAND donotoptimize_test --benchmark_min_time=0.01s) + +compile_benchmark_test(fixture_test) +benchmark_add_test(NAME fixture_test COMMAND fixture_test --benchmark_min_time=0.01s) + +compile_benchmark_test(register_benchmark_test) +benchmark_add_test(NAME register_benchmark_test COMMAND register_benchmark_test --benchmark_min_time=0.01s) + +compile_benchmark_test(map_test) +benchmark_add_test(NAME map_test COMMAND map_test --benchmark_min_time=0.01s) + 
+compile_benchmark_test(multiple_ranges_test) +benchmark_add_test(NAME multiple_ranges_test COMMAND multiple_ranges_test --benchmark_min_time=0.01s) + +compile_benchmark_test(args_product_test) +benchmark_add_test(NAME args_product_test COMMAND args_product_test --benchmark_min_time=0.01s) + +compile_benchmark_test_with_main(link_main_test) +benchmark_add_test(NAME link_main_test COMMAND link_main_test --benchmark_min_time=0.01s) + +compile_output_test(reporter_output_test) +benchmark_add_test(NAME reporter_output_test COMMAND reporter_output_test --benchmark_min_time=0.01s) + +compile_output_test(templated_fixture_test) +benchmark_add_test(NAME templated_fixture_test COMMAND templated_fixture_test --benchmark_min_time=0.01s) + +compile_output_test(user_counters_test) +benchmark_add_test(NAME user_counters_test COMMAND user_counters_test --benchmark_min_time=0.01s) + +compile_output_test(perf_counters_test) +benchmark_add_test(NAME perf_counters_test COMMAND perf_counters_test --benchmark_min_time=0.01s --benchmark_perf_counters=CYCLES,INSTRUCTIONS) + +compile_output_test(internal_threading_test) +benchmark_add_test(NAME internal_threading_test COMMAND internal_threading_test --benchmark_min_time=0.01s) + +compile_output_test(report_aggregates_only_test) +benchmark_add_test(NAME report_aggregates_only_test COMMAND report_aggregates_only_test --benchmark_min_time=0.01s) + +compile_output_test(display_aggregates_only_test) +benchmark_add_test(NAME display_aggregates_only_test COMMAND display_aggregates_only_test --benchmark_min_time=0.01s) + +compile_output_test(user_counters_tabular_test) +benchmark_add_test(NAME user_counters_tabular_test COMMAND user_counters_tabular_test --benchmark_counters_tabular=true --benchmark_min_time=0.01s) + +compile_output_test(user_counters_thousands_test) +benchmark_add_test(NAME user_counters_thousands_test COMMAND user_counters_thousands_test --benchmark_min_time=0.01s) + +compile_output_test(memory_manager_test) 
+benchmark_add_test(NAME memory_manager_test COMMAND memory_manager_test --benchmark_min_time=0.01s) + +compile_output_test(profiler_manager_test) +benchmark_add_test(NAME profiler_manager_test COMMAND profiler_manager_test --benchmark_min_time=0.01s) + +# MSVC does not allow to set the language standard to C++98/03. +if(NOT (MSVC OR CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")) + compile_benchmark_test(cxx03_test) + set_target_properties(cxx03_test + PROPERTIES + CXX_STANDARD 98 + CXX_STANDARD_REQUIRED YES) + # libstdc++ provides different definitions within between dialects. When + # LTO is enabled and -Werror is specified GCC diagnoses this ODR violation + # causing the test to fail to compile. To prevent this we explicitly disable + # the warning. + check_cxx_compiler_flag(-Wno-odr BENCHMARK_HAS_WNO_ODR) + check_cxx_compiler_flag(-Wno-lto-type-mismatch BENCHMARK_HAS_WNO_LTO_TYPE_MISMATCH) + # Cannot set_target_properties multiple times here because the warnings will + # be overwritten on each call + set (DISABLE_LTO_WARNINGS "") + if (BENCHMARK_HAS_WNO_ODR) + set(DISABLE_LTO_WARNINGS "${DISABLE_LTO_WARNINGS} -Wno-odr") + endif() + if (BENCHMARK_HAS_WNO_LTO_TYPE_MISMATCH) + set(DISABLE_LTO_WARNINGS "${DISABLE_LTO_WARNINGS} -Wno-lto-type-mismatch") + endif() + set_target_properties(cxx03_test PROPERTIES LINK_FLAGS "${DISABLE_LTO_WARNINGS}") + benchmark_add_test(NAME cxx03 COMMAND cxx03_test --benchmark_min_time=0.01s) +endif() + +compile_output_test(complexity_test) +benchmark_add_test(NAME complexity_benchmark COMMAND complexity_test --benchmark_min_time=1000000x) + +############################################################################### +# GoogleTest Unit Tests +############################################################################### + +if (BENCHMARK_ENABLE_GTEST_TESTS) + macro(compile_gtest name) + add_executable(${name} "${name}.cc") + target_link_libraries(${name} benchmark::benchmark + gmock_main ${CMAKE_THREAD_LIBS_INIT}) + endmacro(compile_gtest) 
+ + macro(add_gtest name) + compile_gtest(${name}) + benchmark_add_test(NAME ${name} COMMAND ${name}) + if(WIN32 AND BUILD_SHARED_LIBS) + set_tests_properties(${name} PROPERTIES + ENVIRONMENT_MODIFICATION "PATH=path_list_prepend:$;PATH=path_list_prepend:$" + ) + endif() + endmacro() + + add_gtest(benchmark_gtest) + add_gtest(benchmark_name_gtest) + add_gtest(benchmark_random_interleaving_gtest) + add_gtest(commandlineflags_gtest) + add_gtest(statistics_gtest) + add_gtest(string_util_gtest) + add_gtest(perf_counters_gtest) + add_gtest(time_unit_gtest) + add_gtest(min_time_parse_gtest) + add_gtest(profiler_manager_gtest) +endif(BENCHMARK_ENABLE_GTEST_TESTS) + +############################################################################### +# Assembly Unit Tests +############################################################################### + +if (BENCHMARK_ENABLE_ASSEMBLY_TESTS) + if (NOT LLVM_FILECHECK_EXE) + message(FATAL_ERROR "LLVM FileCheck is required when including this file") + endif() + include(AssemblyTests.cmake) + add_filecheck_test(donotoptimize_assembly_test) + add_filecheck_test(state_assembly_test) + add_filecheck_test(clobber_memory_assembly_test) +endif() + + + +############################################################################### +# Code Coverage Configuration +############################################################################### + +# Add the coverage command(s) +if(CMAKE_BUILD_TYPE) + string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER) +endif() +if (${CMAKE_BUILD_TYPE_LOWER} MATCHES "coverage") + find_program(GCOV gcov) + find_program(LCOV lcov) + find_program(GENHTML genhtml) + find_program(CTEST ctest) + if (GCOV AND LCOV AND GENHTML AND CTEST AND HAVE_CXX_FLAG_COVERAGE) + add_custom_command( + OUTPUT ${CMAKE_BINARY_DIR}/lcov/index.html + COMMAND ${LCOV} -q -z -d . + COMMAND ${LCOV} -q --no-external -c -b "${CMAKE_SOURCE_DIR}" -d . 
-o before.lcov -i + COMMAND ${CTEST} --force-new-ctest-process + COMMAND ${LCOV} -q --no-external -c -b "${CMAKE_SOURCE_DIR}" -d . -o after.lcov + COMMAND ${LCOV} -q -a before.lcov -a after.lcov --output-file final.lcov + COMMAND ${LCOV} -q -r final.lcov "'${CMAKE_SOURCE_DIR}/test/*'" -o final.lcov + COMMAND ${GENHTML} final.lcov -o lcov --demangle-cpp --sort -p "${CMAKE_BINARY_DIR}" -t benchmark + DEPENDS filter_test benchmark_test options_test basic_test fixture_test cxx03_test complexity_test + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running LCOV" + ) + add_custom_target(coverage + DEPENDS ${CMAKE_BINARY_DIR}/lcov/index.html + COMMENT "LCOV report at lcov/index.html" + ) + message(STATUS "Coverage command added") + else() + if (HAVE_CXX_FLAG_COVERAGE) + set(CXX_FLAG_COVERAGE_MESSAGE supported) + else() + set(CXX_FLAG_COVERAGE_MESSAGE unavailable) + endif() + message(WARNING + "Coverage not available:\n" + " gcov: ${GCOV}\n" + " lcov: ${LCOV}\n" + " genhtml: ${GENHTML}\n" + " ctest: ${CTEST}\n" + " --coverage flag: ${CXX_FLAG_COVERAGE_MESSAGE}") + endif() +endif() diff --git a/third_party/benchmark/test/args_product_test.cc b/third_party/benchmark/test/args_product_test.cc new file mode 100644 index 0000000..63b8b71 --- /dev/null +++ b/third_party/benchmark/test/args_product_test.cc @@ -0,0 +1,77 @@ +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +class ArgsProductFixture : public ::benchmark::Fixture { + public: + ArgsProductFixture() + : expectedValues({{0, 100, 2000, 30000}, + {1, 15, 3, 8}, + {1, 15, 3, 9}, + {1, 15, 7, 8}, + {1, 15, 7, 9}, + {1, 15, 10, 8}, + {1, 15, 10, 9}, + {2, 15, 3, 8}, + {2, 15, 3, 9}, + {2, 15, 7, 8}, + {2, 15, 7, 9}, + {2, 15, 10, 8}, + {2, 15, 10, 9}, + {4, 5, 6, 11}}) {} + + void SetUp(const ::benchmark::State& state) override { + std::vector ranges = {state.range(0), state.range(1), + state.range(2), state.range(3)}; + + assert(expectedValues.find(ranges) != expectedValues.end()); + + 
actualValues.insert(ranges); + } + + // NOTE: This is not TearDown as we want to check after _all_ runs are + // complete. + ~ArgsProductFixture() override { + if (actualValues != expectedValues) { + std::cout << "EXPECTED\n"; + for (const auto& v : expectedValues) { + std::cout << "{"; + for (int64_t iv : v) { + std::cout << iv << ", "; + } + std::cout << "}\n"; + } + std::cout << "ACTUAL\n"; + for (const auto& v : actualValues) { + std::cout << "{"; + for (int64_t iv : v) { + std::cout << iv << ", "; + } + std::cout << "}\n"; + } + } + } + + std::set> expectedValues; + std::set> actualValues; +}; + +BENCHMARK_DEFINE_F(ArgsProductFixture, Empty)(benchmark::State& state) { + for (auto _ : state) { + int64_t product = + state.range(0) * state.range(1) * state.range(2) * state.range(3); + for (int64_t x = 0; x < product; x++) { + benchmark::DoNotOptimize(x); + } + } +} + +BENCHMARK_REGISTER_F(ArgsProductFixture, Empty) + ->Args({0, 100, 2000, 30000}) + ->ArgsProduct({{1, 2}, {15}, {3, 7, 10}, {8, 9}}) + ->Args({4, 5, 6, 11}); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/basic_test.cc b/third_party/benchmark/test/basic_test.cc new file mode 100644 index 0000000..c25bec7 --- /dev/null +++ b/third_party/benchmark/test/basic_test.cc @@ -0,0 +1,180 @@ + +#include "benchmark/benchmark.h" + +#define BASIC_BENCHMARK_TEST(x) BENCHMARK(x)->Arg(8)->Arg(512)->Arg(8192) + +void BM_empty(benchmark::State& state) { + for (auto _ : state) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } +} +BENCHMARK(BM_empty); +BENCHMARK(BM_empty)->ThreadPerCpu(); + +void BM_spin_empty(benchmark::State& state) { + for (auto _ : state) { + for (auto x = 0; x < state.range(0); ++x) { + benchmark::DoNotOptimize(x); + } + } +} +BASIC_BENCHMARK_TEST(BM_spin_empty); +BASIC_BENCHMARK_TEST(BM_spin_empty)->ThreadPerCpu(); + +void BM_spin_pause_before(benchmark::State& state) { + for (auto i = 0; i < state.range(0); 
++i) { + benchmark::DoNotOptimize(i); + } + for (auto _ : state) { + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } + } +} +BASIC_BENCHMARK_TEST(BM_spin_pause_before); +BASIC_BENCHMARK_TEST(BM_spin_pause_before)->ThreadPerCpu(); + +void BM_spin_pause_during(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } + state.ResumeTiming(); + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } + } +} +BASIC_BENCHMARK_TEST(BM_spin_pause_during); +BASIC_BENCHMARK_TEST(BM_spin_pause_during)->ThreadPerCpu(); + +void BM_pause_during(benchmark::State& state) { + for (auto _ : state) { + state.PauseTiming(); + state.ResumeTiming(); + } +} +BENCHMARK(BM_pause_during); +BENCHMARK(BM_pause_during)->ThreadPerCpu(); +BENCHMARK(BM_pause_during)->UseRealTime(); +BENCHMARK(BM_pause_during)->UseRealTime()->ThreadPerCpu(); + +void BM_spin_pause_after(benchmark::State& state) { + for (auto _ : state) { + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } + } + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } +} +BASIC_BENCHMARK_TEST(BM_spin_pause_after); +BASIC_BENCHMARK_TEST(BM_spin_pause_after)->ThreadPerCpu(); + +void BM_spin_pause_before_and_after(benchmark::State& state) { + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } + for (auto _ : state) { + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } + } + for (auto i = 0; i < state.range(0); ++i) { + benchmark::DoNotOptimize(i); + } +} +BASIC_BENCHMARK_TEST(BM_spin_pause_before_and_after); +BASIC_BENCHMARK_TEST(BM_spin_pause_before_and_after)->ThreadPerCpu(); + +void BM_empty_stop_start(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_empty_stop_start); +BENCHMARK(BM_empty_stop_start)->ThreadPerCpu(); + +void 
BM_KeepRunning(benchmark::State& state) { + benchmark::IterationCount iter_count = 0; + assert(iter_count == state.iterations()); + while (state.KeepRunning()) { + ++iter_count; + } + assert(iter_count == state.iterations()); +} +BENCHMARK(BM_KeepRunning); + +void BM_KeepRunningBatch(benchmark::State& state) { + // Choose a batch size >1000 to skip the typical runs with iteration + // targets of 10, 100 and 1000. If these are not actually skipped the + // bug would be detectable as consecutive runs with the same iteration + // count. Below we assert that this does not happen. + const benchmark::IterationCount batch_size = 1009; + + static benchmark::IterationCount prior_iter_count = 0; + benchmark::IterationCount iter_count = 0; + while (state.KeepRunningBatch(batch_size)) { + iter_count += batch_size; + } + assert(state.iterations() == iter_count); + + // Verify that the iteration count always increases across runs (see + // comment above). + assert(iter_count == batch_size // max_iterations == 1 + || iter_count > prior_iter_count); // max_iterations > batch_size + prior_iter_count = iter_count; +} +// Register with a fixed repetition count to establish the invariant that +// the iteration count should always change across runs. This overrides +// the --benchmark_repetitions command line flag, which would otherwise +// cause this test to fail if set > 1. 
+BENCHMARK(BM_KeepRunningBatch)->Repetitions(1); + +void BM_RangedFor(benchmark::State& state) { + benchmark::IterationCount iter_count = 0; + for (auto _ : state) { + ++iter_count; + } + assert(iter_count == state.max_iterations); +} +BENCHMARK(BM_RangedFor); + +#ifdef BENCHMARK_HAS_CXX11 +template +void BM_OneTemplateFunc(benchmark::State& state) { + auto arg = state.range(0); + T sum = 0; + for (auto _ : state) { + sum += static_cast(arg); + } +} +BENCHMARK(BM_OneTemplateFunc)->Arg(1); +BENCHMARK(BM_OneTemplateFunc)->Arg(1); + +template +void BM_TwoTemplateFunc(benchmark::State& state) { + auto arg = state.range(0); + A sum = 0; + B prod = 1; + for (auto _ : state) { + sum += static_cast(arg); + prod *= static_cast(arg); + } +} +BENCHMARK(BM_TwoTemplateFunc)->Arg(1); +BENCHMARK(BM_TwoTemplateFunc)->Arg(1); + +#endif // BENCHMARK_HAS_CXX11 + +// Ensure that StateIterator provides all the necessary typedefs required to +// instantiate std::iterator_traits. +static_assert( + std::is_same::value_type, + typename benchmark::State::StateIterator::value_type>::value, + ""); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/benchmark_gtest.cc b/third_party/benchmark/test/benchmark_gtest.cc new file mode 100644 index 0000000..0aa2552 --- /dev/null +++ b/third_party/benchmark/test/benchmark_gtest.cc @@ -0,0 +1,169 @@ +#include +#include +#include + +#include "../src/benchmark_register.h" +#include "benchmark/benchmark.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace benchmark { +namespace internal { + +namespace { + +TEST(AddRangeTest, Simple) { + std::vector dst; + AddRange(&dst, 1, 2, 2); + EXPECT_THAT(dst, testing::ElementsAre(1, 2)); +} + +TEST(AddRangeTest, Simple64) { + std::vector dst; + AddRange(&dst, static_cast(1), static_cast(2), 2); + EXPECT_THAT(dst, testing::ElementsAre(1, 2)); +} + +TEST(AddRangeTest, Advanced) { + std::vector dst; + AddRange(&dst, 5, 15, 2); + EXPECT_THAT(dst, testing::ElementsAre(5, 8, 15)); +} + 
+TEST(AddRangeTest, Advanced64) { + std::vector dst; + AddRange(&dst, static_cast(5), static_cast(15), 2); + EXPECT_THAT(dst, testing::ElementsAre(5, 8, 15)); +} + +TEST(AddRangeTest, FullRange8) { + std::vector dst; + AddRange(&dst, int8_t{1}, std::numeric_limits::max(), 8); + EXPECT_THAT( + dst, testing::ElementsAre(int8_t{1}, int8_t{8}, int8_t{64}, int8_t{127})); +} + +TEST(AddRangeTest, FullRange64) { + std::vector dst; + AddRange(&dst, int64_t{1}, std::numeric_limits::max(), 1024); + EXPECT_THAT( + dst, testing::ElementsAre(1LL, 1024LL, 1048576LL, 1073741824LL, + 1099511627776LL, 1125899906842624LL, + 1152921504606846976LL, 9223372036854775807LL)); +} + +TEST(AddRangeTest, NegativeRanges) { + std::vector dst; + AddRange(&dst, -8, 0, 2); + EXPECT_THAT(dst, testing::ElementsAre(-8, -4, -2, -1, 0)); +} + +TEST(AddRangeTest, StrictlyNegative) { + std::vector dst; + AddRange(&dst, -8, -1, 2); + EXPECT_THAT(dst, testing::ElementsAre(-8, -4, -2, -1)); +} + +TEST(AddRangeTest, SymmetricNegativeRanges) { + std::vector dst; + AddRange(&dst, -8, 8, 2); + EXPECT_THAT(dst, testing::ElementsAre(-8, -4, -2, -1, 0, 1, 2, 4, 8)); +} + +TEST(AddRangeTest, SymmetricNegativeRangesOddMult) { + std::vector dst; + AddRange(&dst, -30, 32, 5); + EXPECT_THAT(dst, testing::ElementsAre(-30, -25, -5, -1, 0, 1, 5, 25, 32)); +} + +TEST(AddRangeTest, NegativeRangesAsymmetric) { + std::vector dst; + AddRange(&dst, -3, 5, 2); + EXPECT_THAT(dst, testing::ElementsAre(-3, -2, -1, 0, 1, 2, 4, 5)); +} + +TEST(AddRangeTest, NegativeRangesLargeStep) { + // Always include -1, 0, 1 when crossing zero. 
+ std::vector dst; + AddRange(&dst, -8, 8, 10); + EXPECT_THAT(dst, testing::ElementsAre(-8, -1, 0, 1, 8)); +} + +TEST(AddRangeTest, ZeroOnlyRange) { + std::vector dst; + AddRange(&dst, 0, 0, 2); + EXPECT_THAT(dst, testing::ElementsAre(0)); +} + +TEST(AddRangeTest, ZeroStartingRange) { + std::vector dst; + AddRange(&dst, 0, 2, 2); + EXPECT_THAT(dst, testing::ElementsAre(0, 1, 2)); +} + +TEST(AddRangeTest, NegativeRange64) { + std::vector dst; + AddRange(&dst, -4, 4, 2); + EXPECT_THAT(dst, testing::ElementsAre(-4, -2, -1, 0, 1, 2, 4)); +} + +TEST(AddRangeTest, NegativeRangePreservesExistingOrder) { + // If elements already exist in the range, ensure we don't change + // their ordering by adding negative values. + std::vector dst = {1, 2, 3}; + AddRange(&dst, -2, 2, 2); + EXPECT_THAT(dst, testing::ElementsAre(1, 2, 3, -2, -1, 0, 1, 2)); +} + +TEST(AddRangeTest, FullNegativeRange64) { + std::vector dst; + const auto min = std::numeric_limits::min(); + const auto max = std::numeric_limits::max(); + AddRange(&dst, min, max, 1024); + EXPECT_THAT( + dst, testing::ElementsAreArray(std::vector{ + min, -1152921504606846976LL, -1125899906842624LL, + -1099511627776LL, -1073741824LL, -1048576LL, -1024LL, -1LL, 0LL, + 1LL, 1024LL, 1048576LL, 1073741824LL, 1099511627776LL, + 1125899906842624LL, 1152921504606846976LL, max})); +} + +TEST(AddRangeTest, Simple8) { + std::vector dst; + AddRange(&dst, int8_t{1}, int8_t{8}, int8_t{2}); + EXPECT_THAT(dst, + testing::ElementsAre(int8_t{1}, int8_t{2}, int8_t{4}, int8_t{8})); +} + +TEST(AddCustomContext, Simple) { + std::map *&global_context = GetGlobalContext(); + EXPECT_THAT(global_context, nullptr); + + AddCustomContext("foo", "bar"); + AddCustomContext("baz", "qux"); + + EXPECT_THAT(*global_context, + testing::UnorderedElementsAre(testing::Pair("foo", "bar"), + testing::Pair("baz", "qux"))); + + delete global_context; + global_context = nullptr; +} + +TEST(AddCustomContext, DuplicateKey) { + std::map *&global_context = 
GetGlobalContext(); + EXPECT_THAT(global_context, nullptr); + + AddCustomContext("foo", "bar"); + AddCustomContext("foo", "qux"); + + EXPECT_THAT(*global_context, + testing::UnorderedElementsAre(testing::Pair("foo", "bar"))); + + delete global_context; + global_context = nullptr; +} + +} // namespace +} // namespace internal +} // namespace benchmark diff --git a/third_party/benchmark/test/benchmark_min_time_flag_iters_test.cc b/third_party/benchmark/test/benchmark_min_time_flag_iters_test.cc new file mode 100644 index 0000000..3de93a7 --- /dev/null +++ b/third_party/benchmark/test/benchmark_min_time_flag_iters_test.cc @@ -0,0 +1,66 @@ +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +// Tests that we can specify the number of iterations with +// --benchmark_min_time=x. +namespace { + +class TestReporter : public benchmark::ConsoleReporter { + public: + virtual bool ReportContext(const Context& context) BENCHMARK_OVERRIDE { + return ConsoleReporter::ReportContext(context); + }; + + virtual void ReportRuns(const std::vector& report) BENCHMARK_OVERRIDE { + assert(report.size() == 1); + iter_nums_.push_back(report[0].iterations); + ConsoleReporter::ReportRuns(report); + }; + + TestReporter() {} + + virtual ~TestReporter() {} + + const std::vector& GetIters() const { + return iter_nums_; + } + + private: + std::vector iter_nums_; +}; + +} // end namespace + +static void BM_MyBench(benchmark::State& state) { + for (auto s : state) { + } +} +BENCHMARK(BM_MyBench); + +int main(int argc, char** argv) { + // Make a fake argv and append the new --benchmark_min_time= to it. 
+ int fake_argc = argc + 1; + const char** fake_argv = new const char*[static_cast(fake_argc)]; + for (int i = 0; i < argc; ++i) fake_argv[i] = argv[i]; + fake_argv[argc] = "--benchmark_min_time=4x"; + + benchmark::Initialize(&fake_argc, const_cast(fake_argv)); + + TestReporter test_reporter; + const size_t returned_count = + benchmark::RunSpecifiedBenchmarks(&test_reporter, "BM_MyBench"); + assert(returned_count == 1); + + // Check the executed iters. + const std::vector iters = test_reporter.GetIters(); + assert(!iters.empty() && iters[0] == 4); + + delete[] fake_argv; + return 0; +} diff --git a/third_party/benchmark/test/benchmark_min_time_flag_time_test.cc b/third_party/benchmark/test/benchmark_min_time_flag_time_test.cc new file mode 100644 index 0000000..04a82eb --- /dev/null +++ b/third_party/benchmark/test/benchmark_min_time_flag_time_test.cc @@ -0,0 +1,90 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +// Tests that we can specify the min time with +// --benchmark_min_time= (no suffix needed) OR +// --benchmark_min_time=s +namespace { + +// This is from benchmark.h +typedef int64_t IterationCount; + +class TestReporter : public benchmark::ConsoleReporter { + public: + virtual bool ReportContext(const Context& context) BENCHMARK_OVERRIDE { + return ConsoleReporter::ReportContext(context); + }; + + virtual void ReportRuns(const std::vector& report) BENCHMARK_OVERRIDE { + assert(report.size() == 1); + ConsoleReporter::ReportRuns(report); + }; + + virtual void ReportRunsConfig(double min_time, bool /* has_explicit_iters */, + IterationCount /* iters */) BENCHMARK_OVERRIDE { + min_times_.push_back(min_time); + } + + TestReporter() {} + + virtual ~TestReporter() {} + + const std::vector& GetMinTimes() const { return min_times_; } + + private: + std::vector min_times_; +}; + +bool AlmostEqual(double a, double b) { + return std::fabs(a - b) < std::numeric_limits::epsilon(); +} + +void 
DoTestHelper(int* argc, const char** argv, double expected) { + benchmark::Initialize(argc, const_cast(argv)); + + TestReporter test_reporter; + const size_t returned_count = + benchmark::RunSpecifiedBenchmarks(&test_reporter, "BM_MyBench"); + assert(returned_count == 1); + + // Check the min_time + const std::vector& min_times = test_reporter.GetMinTimes(); + assert(!min_times.empty() && AlmostEqual(min_times[0], expected)); +} + +} // end namespace + +static void BM_MyBench(benchmark::State& state) { + for (auto s : state) { + } +} +BENCHMARK(BM_MyBench); + +int main(int argc, char** argv) { + // Make a fake argv and append the new --benchmark_min_time= to it. + int fake_argc = argc + 1; + const char** fake_argv = new const char*[static_cast(fake_argc)]; + + for (int i = 0; i < argc; ++i) fake_argv[i] = argv[i]; + + const char* no_suffix = "--benchmark_min_time=4"; + const char* with_suffix = "--benchmark_min_time=4.0s"; + double expected = 4.0; + + fake_argv[argc] = no_suffix; + DoTestHelper(&fake_argc, fake_argv, expected); + + fake_argv[argc] = with_suffix; + DoTestHelper(&fake_argc, fake_argv, expected); + + delete[] fake_argv; + return 0; +} diff --git a/third_party/benchmark/test/benchmark_name_gtest.cc b/third_party/benchmark/test/benchmark_name_gtest.cc new file mode 100644 index 0000000..0a6746d --- /dev/null +++ b/third_party/benchmark/test/benchmark_name_gtest.cc @@ -0,0 +1,82 @@ +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" + +namespace { + +using namespace benchmark; +using namespace benchmark::internal; + +TEST(BenchmarkNameTest, Empty) { + const auto name = BenchmarkName(); + EXPECT_EQ(name.str(), std::string()); +} + +TEST(BenchmarkNameTest, FunctionName) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + EXPECT_EQ(name.str(), "function_name"); +} + +TEST(BenchmarkNameTest, FunctionNameAndArgs) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.args = "some_args:3/4/5"; + 
EXPECT_EQ(name.str(), "function_name/some_args:3/4/5"); +} + +TEST(BenchmarkNameTest, MinTime) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.args = "some_args:3/4"; + name.min_time = "min_time:3.4s"; + EXPECT_EQ(name.str(), "function_name/some_args:3/4/min_time:3.4s"); +} + +TEST(BenchmarkNameTest, MinWarmUpTime) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.args = "some_args:3/4"; + name.min_warmup_time = "min_warmup_time:3.5s"; + EXPECT_EQ(name.str(), "function_name/some_args:3/4/min_warmup_time:3.5s"); +} + +TEST(BenchmarkNameTest, Iterations) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.min_time = "min_time:3.4s"; + name.iterations = "iterations:42"; + EXPECT_EQ(name.str(), "function_name/min_time:3.4s/iterations:42"); +} + +TEST(BenchmarkNameTest, Repetitions) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.min_time = "min_time:3.4s"; + name.repetitions = "repetitions:24"; + EXPECT_EQ(name.str(), "function_name/min_time:3.4s/repetitions:24"); +} + +TEST(BenchmarkNameTest, TimeType) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.min_time = "min_time:3.4s"; + name.time_type = "hammer_time"; + EXPECT_EQ(name.str(), "function_name/min_time:3.4s/hammer_time"); +} + +TEST(BenchmarkNameTest, Threads) { + auto name = BenchmarkName(); + name.function_name = "function_name"; + name.min_time = "min_time:3.4s"; + name.threads = "threads:256"; + EXPECT_EQ(name.str(), "function_name/min_time:3.4s/threads:256"); +} + +TEST(BenchmarkNameTest, TestEmptyFunctionName) { + auto name = BenchmarkName(); + name.args = "first:3/second:4"; + name.threads = "threads:22"; + EXPECT_EQ(name.str(), "first:3/second:4/threads:22"); +} + +} // end namespace diff --git a/third_party/benchmark/test/benchmark_random_interleaving_gtest.cc b/third_party/benchmark/test/benchmark_random_interleaving_gtest.cc new file mode 100644 index 
0000000..7f20867 --- /dev/null +++ b/third_party/benchmark/test/benchmark_random_interleaving_gtest.cc @@ -0,0 +1,126 @@ +#include +#include +#include + +#include "../src/commandlineflags.h" +#include "../src/string_util.h" +#include "benchmark/benchmark.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace benchmark { + +BM_DECLARE_bool(benchmark_enable_random_interleaving); +BM_DECLARE_string(benchmark_filter); +BM_DECLARE_int32(benchmark_repetitions); + +namespace internal { +namespace { + +class EventQueue : public std::queue { + public: + void Put(const std::string& event) { push(event); } + + void Clear() { + while (!empty()) { + pop(); + } + } + + std::string Get() { + std::string event = front(); + pop(); + return event; + } +}; + +EventQueue* queue = new EventQueue(); + +class NullReporter : public BenchmarkReporter { + public: + bool ReportContext(const Context& /*context*/) override { return true; } + void ReportRuns(const std::vector& /* report */) override {} +}; + +class BenchmarkTest : public testing::Test { + public: + static void SetupHook(int /* num_threads */) { queue->push("Setup"); } + + static void TeardownHook(int /* num_threads */) { queue->push("Teardown"); } + + void Execute(const std::string& pattern) { + queue->Clear(); + + std::unique_ptr reporter(new NullReporter()); + FLAGS_benchmark_filter = pattern; + RunSpecifiedBenchmarks(reporter.get()); + + queue->Put("DONE"); // End marker + } +}; + +void BM_Match1(benchmark::State& state) { + const int64_t arg = state.range(0); + + for (auto _ : state) { + } + queue->Put(StrFormat("BM_Match1/%d", static_cast(arg))); +} +BENCHMARK(BM_Match1) + ->Iterations(100) + ->Arg(1) + ->Arg(2) + ->Arg(3) + ->Range(10, 80) + ->Args({90}) + ->Args({100}); + +TEST_F(BenchmarkTest, Match1) { + Execute("BM_Match1"); + ASSERT_EQ("BM_Match1/1", queue->Get()); + ASSERT_EQ("BM_Match1/2", queue->Get()); + ASSERT_EQ("BM_Match1/3", queue->Get()); + ASSERT_EQ("BM_Match1/10", queue->Get()); + 
ASSERT_EQ("BM_Match1/64", queue->Get()); + ASSERT_EQ("BM_Match1/80", queue->Get()); + ASSERT_EQ("BM_Match1/90", queue->Get()); + ASSERT_EQ("BM_Match1/100", queue->Get()); + ASSERT_EQ("DONE", queue->Get()); +} + +TEST_F(BenchmarkTest, Match1WithRepetition) { + FLAGS_benchmark_repetitions = 2; + + Execute("BM_Match1/(64|80)"); + ASSERT_EQ("BM_Match1/64", queue->Get()); + ASSERT_EQ("BM_Match1/64", queue->Get()); + ASSERT_EQ("BM_Match1/80", queue->Get()); + ASSERT_EQ("BM_Match1/80", queue->Get()); + ASSERT_EQ("DONE", queue->Get()); +} + +TEST_F(BenchmarkTest, Match1WithRandomInterleaving) { + FLAGS_benchmark_enable_random_interleaving = true; + FLAGS_benchmark_repetitions = 100; + + std::map element_count; + std::map interleaving_count; + Execute("BM_Match1/(64|80)"); + for (int i = 0; i < 100; ++i) { + std::vector interleaving; + interleaving.push_back(queue->Get()); + interleaving.push_back(queue->Get()); + element_count[interleaving[0]]++; + element_count[interleaving[1]]++; + interleaving_count[StrFormat("%s,%s", interleaving[0].c_str(), + interleaving[1].c_str())]++; + } + EXPECT_EQ(element_count["BM_Match1/64"], 100) << "Unexpected repetitions."; + EXPECT_EQ(element_count["BM_Match1/80"], 100) << "Unexpected repetitions."; + EXPECT_GE(interleaving_count.size(), 2) << "Interleaving was not randomized."; + ASSERT_EQ("DONE", queue->Get()); +} + +} // namespace +} // namespace internal +} // namespace benchmark diff --git a/third_party/benchmark/test/benchmark_setup_teardown_test.cc b/third_party/benchmark/test/benchmark_setup_teardown_test.cc new file mode 100644 index 0000000..6c3cc2e --- /dev/null +++ b/third_party/benchmark/test/benchmark_setup_teardown_test.cc @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +// Test that Setup() and Teardown() are called exactly once +// for each benchmark run (single-threaded). 
+namespace singlethreaded { +static int setup_call = 0; +static int teardown_call = 0; +} // namespace singlethreaded +static void DoSetup1(const benchmark::State& state) { + ++singlethreaded::setup_call; + + // Setup/Teardown should never be called with any thread_idx != 0. + assert(state.thread_index() == 0); +} + +static void DoTeardown1(const benchmark::State& state) { + ++singlethreaded::teardown_call; + assert(state.thread_index() == 0); +} + +static void BM_with_setup(benchmark::State& state) { + for (auto s : state) { + } +} +BENCHMARK(BM_with_setup) + ->Arg(1) + ->Arg(3) + ->Arg(5) + ->Arg(7) + ->Iterations(100) + ->Setup(DoSetup1) + ->Teardown(DoTeardown1); + +// Test that Setup() and Teardown() are called once for each group of threads. +namespace concurrent { +static std::atomic setup_call(0); +static std::atomic teardown_call(0); +static std::atomic func_call(0); +} // namespace concurrent + +static void DoSetup2(const benchmark::State& state) { + concurrent::setup_call.fetch_add(1, std::memory_order_acquire); + assert(state.thread_index() == 0); +} + +static void DoTeardown2(const benchmark::State& state) { + concurrent::teardown_call.fetch_add(1, std::memory_order_acquire); + assert(state.thread_index() == 0); +} + +static void BM_concurrent(benchmark::State& state) { + for (auto s : state) { + } + concurrent::func_call.fetch_add(1, std::memory_order_acquire); +} + +BENCHMARK(BM_concurrent) + ->Setup(DoSetup2) + ->Teardown(DoTeardown2) + ->Iterations(100) + ->Threads(5) + ->Threads(10) + ->Threads(15); + +// Testing interaction with Fixture::Setup/Teardown +namespace fixture_interaction { +int setup = 0; +int fixture_setup = 0; +} // namespace fixture_interaction + +#define FIXTURE_BECHMARK_NAME MyFixture + +class FIXTURE_BECHMARK_NAME : public ::benchmark::Fixture { + public: + void SetUp(const ::benchmark::State&) override { + fixture_interaction::fixture_setup++; + } + + ~FIXTURE_BECHMARK_NAME() override {} +}; + 
+BENCHMARK_F(FIXTURE_BECHMARK_NAME, BM_WithFixture)(benchmark::State& st) {
+  for (auto _ : st) {
+  }
+}
+
+static void DoSetupWithFixture(const benchmark::State&) {
+  fixture_interaction::setup++;
+}
+
+BENCHMARK_REGISTER_F(FIXTURE_BECHMARK_NAME, BM_WithFixture)
+    ->Arg(1)
+    ->Arg(3)
+    ->Arg(5)
+    ->Arg(7)
+    ->Setup(DoSetupWithFixture)
+    ->Repetitions(1)
+    ->Iterations(100);
+
+// Testing repetitions.
+namespace repetitions {
+int setup = 0;
+}
+
+static void DoSetupWithRepetitions(const benchmark::State&) {
+  repetitions::setup++;
+}
+static void BM_WithRep(benchmark::State& state) {
+  for (auto _ : state) {
+  }
+}
+
+BENCHMARK(BM_WithRep)
+    ->Arg(1)
+    ->Arg(3)
+    ->Arg(5)
+    ->Arg(7)
+    ->Setup(DoSetupWithRepetitions)
+    ->Iterations(100)
+    ->Repetitions(4);
+
+int main(int argc, char** argv) {
+  benchmark::Initialize(&argc, argv);
+
+  size_t ret = benchmark::RunSpecifiedBenchmarks(".");
+  assert(ret > 0);
+
+  // Setup/Teardown is called once for each arg group (1,3,5,7).
+  assert(singlethreaded::setup_call == 4);
+  assert(singlethreaded::teardown_call == 4);
+
+  // 3 groups of threads call this function (5,10,15).
+  assert(concurrent::setup_call.load(std::memory_order_relaxed) == 3);
+  assert(concurrent::teardown_call.load(std::memory_order_relaxed) == 3);
+  assert((5 + 10 + 15) ==
+         concurrent::func_call.load(std::memory_order_relaxed));
+
+  // Setup is called 4 times, once for each arg group (1,3,5,7)
+  assert(fixture_interaction::setup == 4);
+  // Fixture::Setup is called every time the bm routine is run.
+  // The exact number is nondeterministic, so we just assert that
+  // it's more than setup.
+  assert(fixture_interaction::fixture_setup > fixture_interaction::setup);
+
+  // Setup is called once for each repetition * num_arg = 4 * 4 = 16.
+ assert(repetitions::setup == 16); + + return 0; +} diff --git a/third_party/benchmark/test/benchmark_test.cc b/third_party/benchmark/test/benchmark_test.cc new file mode 100644 index 0000000..8b14017 --- /dev/null +++ b/third_party/benchmark/test/benchmark_test.cc @@ -0,0 +1,300 @@ +#include "benchmark/benchmark.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__GNUC__) +#define BENCHMARK_NOINLINE __attribute__((noinline)) +#else +#define BENCHMARK_NOINLINE +#endif + +namespace { + +int BENCHMARK_NOINLINE Factorial(int n) { + return (n == 1) ? 1 : n * Factorial(n - 1); +} + +double CalculatePi(int depth) { + double pi = 0.0; + for (int i = 0; i < depth; ++i) { + double numerator = static_cast(((i % 2) * 2) - 1); + double denominator = static_cast((2 * i) - 1); + pi += numerator / denominator; + } + return (pi - 1.0) * 4; +} + +std::set ConstructRandomSet(int64_t size) { + std::set s; + for (int i = 0; i < size; ++i) s.insert(s.end(), i); + return s; +} + +std::mutex test_vector_mu; +std::vector* test_vector = nullptr; + +} // end namespace + +static void BM_Factorial(benchmark::State& state) { + int fac_42 = 0; + for (auto _ : state) fac_42 = Factorial(8); + // Prevent compiler optimizations + std::stringstream ss; + ss << fac_42; + state.SetLabel(ss.str()); +} +BENCHMARK(BM_Factorial); +BENCHMARK(BM_Factorial)->UseRealTime(); + +static void BM_CalculatePiRange(benchmark::State& state) { + double pi = 0.0; + for (auto _ : state) pi = CalculatePi(static_cast(state.range(0))); + std::stringstream ss; + ss << pi; + state.SetLabel(ss.str()); +} +BENCHMARK_RANGE(BM_CalculatePiRange, 1, 1024 * 1024); + +static void BM_CalculatePi(benchmark::State& state) { + static const int depth = 1024; + for (auto _ : state) { + double pi = CalculatePi(static_cast(depth)); + benchmark::DoNotOptimize(pi); + } +} 
+BENCHMARK(BM_CalculatePi)->Threads(8); +BENCHMARK(BM_CalculatePi)->ThreadRange(1, 32); +BENCHMARK(BM_CalculatePi)->ThreadPerCpu(); + +static void BM_SetInsert(benchmark::State& state) { + std::set data; + for (auto _ : state) { + state.PauseTiming(); + data = ConstructRandomSet(state.range(0)); + state.ResumeTiming(); + for (int j = 0; j < state.range(1); ++j) data.insert(rand()); + } + state.SetItemsProcessed(state.iterations() * state.range(1)); + state.SetBytesProcessed(state.iterations() * state.range(1) * + static_cast(sizeof(int))); +} + +// Test many inserts at once to reduce the total iterations needed. Otherwise, +// the slower, non-timed part of each iteration will make the benchmark take +// forever. +BENCHMARK(BM_SetInsert)->Ranges({{1 << 10, 8 << 10}, {128, 512}}); + +template +static void BM_Sequential(benchmark::State& state) { + ValueType v = 42; + for (auto _ : state) { + Container c; + for (int64_t i = state.range(0); --i;) c.push_back(v); + } + const int64_t items_processed = state.iterations() * state.range(0); + state.SetItemsProcessed(items_processed); + state.SetBytesProcessed(items_processed * static_cast(sizeof(v))); +} +BENCHMARK_TEMPLATE2(BM_Sequential, std::vector, int) + ->Range(1 << 0, 1 << 10); +BENCHMARK_TEMPLATE(BM_Sequential, std::list)->Range(1 << 0, 1 << 10); +// Test the variadic version of BENCHMARK_TEMPLATE in C++11 and beyond. +#ifdef BENCHMARK_HAS_CXX11 +BENCHMARK_TEMPLATE(BM_Sequential, std::vector, int)->Arg(512); +#endif + +static void BM_StringCompare(benchmark::State& state) { + size_t len = static_cast(state.range(0)); + std::string s1(len, '-'); + std::string s2(len, '-'); + for (auto _ : state) { + auto comp = s1.compare(s2); + benchmark::DoNotOptimize(comp); + } +} +BENCHMARK(BM_StringCompare)->Range(1, 1 << 20); + +static void BM_SetupTeardown(benchmark::State& state) { + if (state.thread_index() == 0) { + // No need to lock test_vector_mu here as this is running single-threaded. 
+ test_vector = new std::vector(); + } + int i = 0; + for (auto _ : state) { + std::lock_guard l(test_vector_mu); + if (i % 2 == 0) + test_vector->push_back(i); + else + test_vector->pop_back(); + ++i; + } + if (state.thread_index() == 0) { + delete test_vector; + } +} +BENCHMARK(BM_SetupTeardown)->ThreadPerCpu(); + +static void BM_LongTest(benchmark::State& state) { + double tracker = 0.0; + for (auto _ : state) { + for (int i = 0; i < state.range(0); ++i) + benchmark::DoNotOptimize(tracker += i); + } +} +BENCHMARK(BM_LongTest)->Range(1 << 16, 1 << 28); + +static void BM_ParallelMemset(benchmark::State& state) { + int64_t size = state.range(0) / static_cast(sizeof(int)); + int thread_size = static_cast(size) / state.threads(); + int from = thread_size * state.thread_index(); + int to = from + thread_size; + + if (state.thread_index() == 0) { + test_vector = new std::vector(static_cast(size)); + } + + for (auto _ : state) { + for (int i = from; i < to; i++) { + // No need to lock test_vector_mu as ranges + // do not overlap between threads. 
+ benchmark::DoNotOptimize(test_vector->at(static_cast(i)) = 1); + } + } + + if (state.thread_index() == 0) { + delete test_vector; + } +} +BENCHMARK(BM_ParallelMemset)->Arg(10 << 20)->ThreadRange(1, 4); + +static void BM_ManualTiming(benchmark::State& state) { + int64_t slept_for = 0; + int64_t microseconds = state.range(0); + std::chrono::duration sleep_duration{ + static_cast(microseconds)}; + + for (auto _ : state) { + auto start = std::chrono::high_resolution_clock::now(); + // Simulate some useful workload with a sleep + std::this_thread::sleep_for( + std::chrono::duration_cast(sleep_duration)); + auto end = std::chrono::high_resolution_clock::now(); + + auto elapsed = + std::chrono::duration_cast>(end - start); + + state.SetIterationTime(elapsed.count()); + slept_for += microseconds; + } + state.SetItemsProcessed(slept_for); +} +BENCHMARK(BM_ManualTiming)->Range(1, 1 << 14)->UseRealTime(); +BENCHMARK(BM_ManualTiming)->Range(1, 1 << 14)->UseManualTime(); + +#ifdef BENCHMARK_HAS_CXX11 + +template +void BM_with_args(benchmark::State& state, Args&&...) { + for (auto _ : state) { + } +} +BENCHMARK_CAPTURE(BM_with_args, int_test, 42, 43, 44); +BENCHMARK_CAPTURE(BM_with_args, string_and_pair_test, std::string("abc"), + std::pair(42, 3.8)); + +void BM_non_template_args(benchmark::State& state, int, double) { + while (state.KeepRunning()) { + } +} +BENCHMARK_CAPTURE(BM_non_template_args, basic_test, 0, 0); + +template +void BM_template2_capture(benchmark::State& state, ExtraArgs&&... extra_args) { + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + unsigned int dummy[sizeof...(ExtraArgs)] = {extra_args...}; + assert(dummy[0] == 42); + for (auto _ : state) { + } +} +BENCHMARK_TEMPLATE2_CAPTURE(BM_template2_capture, void, char*, foo, 42U); +BENCHMARK_CAPTURE((BM_template2_capture), foo, 42U); + +template +void BM_template1_capture(benchmark::State& state, ExtraArgs&&... 
extra_args) { + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + unsigned long dummy[sizeof...(ExtraArgs)] = {extra_args...}; + assert(dummy[0] == 24); + for (auto _ : state) { + } +} +BENCHMARK_TEMPLATE1_CAPTURE(BM_template1_capture, void, foo, 24UL); +BENCHMARK_CAPTURE(BM_template1_capture, foo, 24UL); + +#endif // BENCHMARK_HAS_CXX11 + +static void BM_DenseThreadRanges(benchmark::State& st) { + switch (st.range(0)) { + case 1: + assert(st.threads() == 1 || st.threads() == 2 || st.threads() == 3); + break; + case 2: + assert(st.threads() == 1 || st.threads() == 3 || st.threads() == 4); + break; + case 3: + assert(st.threads() == 5 || st.threads() == 8 || st.threads() == 11 || + st.threads() == 14); + break; + default: + assert(false && "Invalid test case number"); + } + while (st.KeepRunning()) { + } +} +BENCHMARK(BM_DenseThreadRanges)->Arg(1)->DenseThreadRange(1, 3); +BENCHMARK(BM_DenseThreadRanges)->Arg(2)->DenseThreadRange(1, 4, 2); +BENCHMARK(BM_DenseThreadRanges)->Arg(3)->DenseThreadRange(5, 14, 3); + +static void BM_BenchmarkName(benchmark::State& state) { + for (auto _ : state) { + } + + // Check that the benchmark name is passed correctly to `state`. 
+ assert("BM_BenchmarkName" == state.name()); +} +BENCHMARK(BM_BenchmarkName); + +// regression test for #1446 +template +static void BM_templated_test(benchmark::State& state) { + for (auto _ : state) { + type created_string; + benchmark::DoNotOptimize(created_string); + } +} + +static auto BM_templated_test_double = BM_templated_test>; +BENCHMARK(BM_templated_test_double); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/clobber_memory_assembly_test.cc b/third_party/benchmark/test/clobber_memory_assembly_test.cc new file mode 100644 index 0000000..54e26cc --- /dev/null +++ b/third_party/benchmark/test/clobber_memory_assembly_test.cc @@ -0,0 +1,64 @@ +#include + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wreturn-type" +#endif +BENCHMARK_DISABLE_DEPRECATED_WARNING + +extern "C" { + +extern int ExternInt; +extern int ExternInt2; +extern int ExternInt3; +} + +// CHECK-LABEL: test_basic: +extern "C" void test_basic() { + int x; + benchmark::DoNotOptimize(&x); + x = 101; + benchmark::ClobberMemory(); + // CHECK: leaq [[DEST:[^,]+]], %rax + // CHECK: movl $101, [[DEST]] + // CHECK: ret +} + +// CHECK-LABEL: test_redundant_store: +extern "C" void test_redundant_store() { + ExternInt = 3; + benchmark::ClobberMemory(); + ExternInt = 51; + // CHECK-DAG: ExternInt + // CHECK-DAG: movl $3 + // CHECK: movl $51 +} + +// CHECK-LABEL: test_redundant_read: +extern "C" void test_redundant_read() { + int x; + benchmark::DoNotOptimize(&x); + x = ExternInt; + benchmark::ClobberMemory(); + x = ExternInt2; + // CHECK: leaq [[DEST:[^,]+]], %rax + // CHECK: ExternInt(%rip) + // CHECK: movl %eax, [[DEST]] + // CHECK-NOT: ExternInt2 + // CHECK: ret +} + +// CHECK-LABEL: test_redundant_read2: +extern "C" void test_redundant_read2() { + int x; + benchmark::DoNotOptimize(&x); + x = ExternInt; + benchmark::ClobberMemory(); + x = ExternInt2; + benchmark::ClobberMemory(); + // CHECK: leaq [[DEST:[^,]+]], %rax + // CHECK: ExternInt(%rip) + // CHECK: movl %eax, [[DEST]] + // 
CHECK: ExternInt2(%rip) + // CHECK: movl %eax, [[DEST]] + // CHECK: ret +} diff --git a/third_party/benchmark/test/commandlineflags_gtest.cc b/third_party/benchmark/test/commandlineflags_gtest.cc new file mode 100644 index 0000000..8412008 --- /dev/null +++ b/third_party/benchmark/test/commandlineflags_gtest.cc @@ -0,0 +1,228 @@ +#include + +#include "../src/commandlineflags.h" +#include "../src/internal_macros.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace benchmark { +namespace { + +#if defined(BENCHMARK_OS_WINDOWS) +int setenv(const char* name, const char* value, int overwrite) { + if (!overwrite) { + // NOTE: getenv_s is far superior but not available under mingw. + char* env_value = getenv(name); + if (env_value == nullptr) { + return -1; + } + } + return _putenv_s(name, value); +} + +int unsetenv(const char* name) { return _putenv_s(name, ""); } + +#endif // BENCHMARK_OS_WINDOWS + +TEST(BoolFromEnv, Default) { + ASSERT_EQ(unsetenv("NOT_IN_ENV"), 0); + EXPECT_EQ(BoolFromEnv("not_in_env", true), true); +} + +TEST(BoolFromEnv, False) { + ASSERT_EQ(setenv("IN_ENV", "0", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "N", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "n", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "NO", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "No", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "no", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "F", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "f", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + 
+ ASSERT_EQ(setenv("IN_ENV", "FALSE", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "False", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "false", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "OFF", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "Off", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "off", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", true), false); + unsetenv("IN_ENV"); +} + +TEST(BoolFromEnv, True) { + ASSERT_EQ(setenv("IN_ENV", "1", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "Y", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "y", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "YES", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "Yes", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "yes", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "T", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "t", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "TRUE", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "True", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "true", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + 
unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "ON", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "On", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + + ASSERT_EQ(setenv("IN_ENV", "on", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); + +#ifndef BENCHMARK_OS_WINDOWS + ASSERT_EQ(setenv("IN_ENV", "", 1), 0); + EXPECT_EQ(BoolFromEnv("in_env", false), true); + unsetenv("IN_ENV"); +#endif +} + +TEST(Int32FromEnv, NotInEnv) { + ASSERT_EQ(unsetenv("NOT_IN_ENV"), 0); + EXPECT_EQ(Int32FromEnv("not_in_env", 42), 42); +} + +TEST(Int32FromEnv, InvalidInteger) { + ASSERT_EQ(setenv("IN_ENV", "foo", 1), 0); + EXPECT_EQ(Int32FromEnv("in_env", 42), 42); + unsetenv("IN_ENV"); +} + +TEST(Int32FromEnv, ValidInteger) { + ASSERT_EQ(setenv("IN_ENV", "42", 1), 0); + EXPECT_EQ(Int32FromEnv("in_env", 64), 42); + unsetenv("IN_ENV"); +} + +TEST(DoubleFromEnv, NotInEnv) { + ASSERT_EQ(unsetenv("NOT_IN_ENV"), 0); + EXPECT_EQ(DoubleFromEnv("not_in_env", 0.51), 0.51); +} + +TEST(DoubleFromEnv, InvalidReal) { + ASSERT_EQ(setenv("IN_ENV", "foo", 1), 0); + EXPECT_EQ(DoubleFromEnv("in_env", 0.51), 0.51); + unsetenv("IN_ENV"); +} + +TEST(DoubleFromEnv, ValidReal) { + ASSERT_EQ(setenv("IN_ENV", "0.51", 1), 0); + EXPECT_EQ(DoubleFromEnv("in_env", 0.71), 0.51); + unsetenv("IN_ENV"); +} + +TEST(StringFromEnv, Default) { + ASSERT_EQ(unsetenv("NOT_IN_ENV"), 0); + EXPECT_STREQ(StringFromEnv("not_in_env", "foo"), "foo"); +} + +TEST(StringFromEnv, Valid) { + ASSERT_EQ(setenv("IN_ENV", "foo", 1), 0); + EXPECT_STREQ(StringFromEnv("in_env", "bar"), "foo"); + unsetenv("IN_ENV"); +} + +TEST(KvPairsFromEnv, Default) { + ASSERT_EQ(unsetenv("NOT_IN_ENV"), 0); + EXPECT_THAT(KvPairsFromEnv("not_in_env", {{"foo", "bar"}}), + testing::ElementsAre(testing::Pair("foo", "bar"))); +} + +TEST(KvPairsFromEnv, MalformedReturnsDefault) { + ASSERT_EQ(setenv("IN_ENV", "foo", 1), 0); + 
EXPECT_THAT(KvPairsFromEnv("in_env", {{"foo", "bar"}}), + testing::ElementsAre(testing::Pair("foo", "bar"))); + unsetenv("IN_ENV"); +} + +TEST(KvPairsFromEnv, Single) { + ASSERT_EQ(setenv("IN_ENV", "foo=bar", 1), 0); + EXPECT_THAT(KvPairsFromEnv("in_env", {}), + testing::ElementsAre(testing::Pair("foo", "bar"))); + unsetenv("IN_ENV"); +} + +TEST(KvPairsFromEnv, Multiple) { + ASSERT_EQ(setenv("IN_ENV", "foo=bar,baz=qux", 1), 0); + EXPECT_THAT(KvPairsFromEnv("in_env", {}), + testing::UnorderedElementsAre(testing::Pair("foo", "bar"), + testing::Pair("baz", "qux"))); + unsetenv("IN_ENV"); +} + +} // namespace +} // namespace benchmark diff --git a/third_party/benchmark/test/complexity_test.cc b/third_party/benchmark/test/complexity_test.cc new file mode 100644 index 0000000..0729d15 --- /dev/null +++ b/third_party/benchmark/test/complexity_test.cc @@ -0,0 +1,270 @@ +#undef NDEBUG +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "output_test.h" + +namespace { + +#define ADD_COMPLEXITY_CASES(...) \ + int CONCAT(dummy, __LINE__) = AddComplexityTest(__VA_ARGS__) + +int AddComplexityTest(const std::string &test_name, + const std::string &big_o_test_name, + const std::string &rms_test_name, + const std::string &big_o, int family_index) { + SetSubstitutions({{"%name", test_name}, + {"%bigo_name", big_o_test_name}, + {"%rms_name", rms_test_name}, + {"%bigo_str", "[ ]* %float " + big_o}, + {"%bigo", big_o}, + {"%rms", "[ ]*[0-9]+ %"}}); + AddCases( + TC_ConsoleOut, + {{"^%bigo_name %bigo_str %bigo_str[ ]*$"}, + {"^%bigo_name", MR_Not}, // Assert we we didn't only matched a name. 
+ {"^%rms_name %rms %rms[ ]*$", MR_Next}}); + AddCases( + TC_JSONOut, + {{"\"name\": \"%bigo_name\",$"}, + {"\"family_index\": " + std::to_string(family_index) + ",$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"%name\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": %int,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"BigO\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"cpu_coefficient\": %float,$", MR_Next}, + {"\"real_coefficient\": %float,$", MR_Next}, + {"\"big_o\": \"%bigo\",$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}, + {"\"name\": \"%rms_name\",$"}, + {"\"family_index\": " + std::to_string(family_index) + ",$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"%name\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": %int,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"RMS\",$", MR_Next}, + {"\"aggregate_unit\": \"percentage\",$", MR_Next}, + {"\"rms\": %float$", MR_Next}, + {"}", MR_Next}}); + AddCases(TC_CSVOut, {{"^\"%bigo_name\",,%float,%float,%bigo,,,,,$"}, + {"^\"%bigo_name\"", MR_Not}, + {"^\"%rms_name\",,%float,%float,,,,,,$", MR_Next}}); + return 0; +} + +} // end namespace + +// ========================================================================= // +// --------------------------- Testing BigO O(1) --------------------------- // +// ========================================================================= // + +void BM_Complexity_O1(benchmark::State &state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + benchmark::DoNotOptimize(state.iterations()); + double tmp = static_cast(state.iterations()); + benchmark::DoNotOptimize(tmp); + for (benchmark::IterationCount i = 0; i < state.iterations(); ++i) { + benchmark::DoNotOptimize(state.iterations()); + tmp *= 
static_cast(state.iterations()); + benchmark::DoNotOptimize(tmp); + } + + // always 1ns per iteration + state.SetIterationTime(42 * 1e-9); + } + state.SetComplexityN(state.range(0)); +} +BENCHMARK(BM_Complexity_O1) + ->Range(1, 1 << 18) + ->UseManualTime() + ->Complexity(benchmark::o1); +BENCHMARK(BM_Complexity_O1)->Range(1, 1 << 18)->UseManualTime()->Complexity(); +BENCHMARK(BM_Complexity_O1) + ->Range(1, 1 << 18) + ->UseManualTime() + ->Complexity([](benchmark::IterationCount) { return 1.0; }); + +const char *one_test_name = "BM_Complexity_O1/manual_time"; +const char *big_o_1_test_name = "BM_Complexity_O1/manual_time_BigO"; +const char *rms_o_1_test_name = "BM_Complexity_O1/manual_time_RMS"; +const char *enum_auto_big_o_1 = "\\([0-9]+\\)"; +const char *lambda_big_o_1 = "f\\(N\\)"; + +// Add enum tests +ADD_COMPLEXITY_CASES(one_test_name, big_o_1_test_name, rms_o_1_test_name, + enum_auto_big_o_1, /*family_index=*/0); + +// Add auto tests +ADD_COMPLEXITY_CASES(one_test_name, big_o_1_test_name, rms_o_1_test_name, + enum_auto_big_o_1, /*family_index=*/1); + +// Add lambda tests +ADD_COMPLEXITY_CASES(one_test_name, big_o_1_test_name, rms_o_1_test_name, + lambda_big_o_1, /*family_index=*/2); + +// ========================================================================= // +// --------------------------- Testing BigO O(N) --------------------------- // +// ========================================================================= // + +void BM_Complexity_O_N(benchmark::State &state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + benchmark::DoNotOptimize(state.iterations()); + double tmp = static_cast(state.iterations()); + benchmark::DoNotOptimize(tmp); + for (benchmark::IterationCount i = 0; i < state.iterations(); ++i) { + benchmark::DoNotOptimize(state.iterations()); + tmp *= static_cast(state.iterations()); + benchmark::DoNotOptimize(tmp); + } + + // 1ns per iteration per entry + 
state.SetIterationTime(static_cast(state.range(0)) * 42 * 1e-9); + } + state.SetComplexityN(state.range(0)); +} +BENCHMARK(BM_Complexity_O_N) + ->RangeMultiplier(2) + ->Range(1 << 10, 1 << 20) + ->UseManualTime() + ->Complexity(benchmark::oN); +BENCHMARK(BM_Complexity_O_N) + ->RangeMultiplier(2) + ->Range(1 << 10, 1 << 20) + ->UseManualTime() + ->Complexity(); +BENCHMARK(BM_Complexity_O_N) + ->RangeMultiplier(2) + ->Range(1 << 10, 1 << 20) + ->UseManualTime() + ->Complexity([](benchmark::IterationCount n) -> double { + return static_cast(n); + }); + +const char *n_test_name = "BM_Complexity_O_N/manual_time"; +const char *big_o_n_test_name = "BM_Complexity_O_N/manual_time_BigO"; +const char *rms_o_n_test_name = "BM_Complexity_O_N/manual_time_RMS"; +const char *enum_auto_big_o_n = "N"; +const char *lambda_big_o_n = "f\\(N\\)"; + +// Add enum tests +ADD_COMPLEXITY_CASES(n_test_name, big_o_n_test_name, rms_o_n_test_name, + enum_auto_big_o_n, /*family_index=*/3); + +// Add auto tests +ADD_COMPLEXITY_CASES(n_test_name, big_o_n_test_name, rms_o_n_test_name, + enum_auto_big_o_n, /*family_index=*/4); + +// Add lambda tests +ADD_COMPLEXITY_CASES(n_test_name, big_o_n_test_name, rms_o_n_test_name, + lambda_big_o_n, /*family_index=*/5); + +// ========================================================================= // +// ------------------------- Testing BigO O(NlgN) ------------------------- // +// ========================================================================= // + +static const double kLog2E = 1.44269504088896340736; +static void BM_Complexity_O_N_log_N(benchmark::State &state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + benchmark::DoNotOptimize(state.iterations()); + double tmp = static_cast(state.iterations()); + benchmark::DoNotOptimize(tmp); + for (benchmark::IterationCount i = 0; i < state.iterations(); ++i) { + benchmark::DoNotOptimize(state.iterations()); + tmp *= static_cast(state.iterations()); + 
benchmark::DoNotOptimize(tmp); + } + + state.SetIterationTime(static_cast(state.range(0)) * kLog2E * + std::log(state.range(0)) * 42 * 1e-9); + } + state.SetComplexityN(state.range(0)); +} +BENCHMARK(BM_Complexity_O_N_log_N) + ->RangeMultiplier(2) + ->Range(1 << 10, 1U << 24) + ->UseManualTime() + ->Complexity(benchmark::oNLogN); +BENCHMARK(BM_Complexity_O_N_log_N) + ->RangeMultiplier(2) + ->Range(1 << 10, 1U << 24) + ->UseManualTime() + ->Complexity(); +BENCHMARK(BM_Complexity_O_N_log_N) + ->RangeMultiplier(2) + ->Range(1 << 10, 1U << 24) + ->UseManualTime() + ->Complexity([](benchmark::IterationCount n) { + return kLog2E * static_cast(n) * std::log(static_cast(n)); + }); + +const char *n_lg_n_test_name = "BM_Complexity_O_N_log_N/manual_time"; +const char *big_o_n_lg_n_test_name = "BM_Complexity_O_N_log_N/manual_time_BigO"; +const char *rms_o_n_lg_n_test_name = "BM_Complexity_O_N_log_N/manual_time_RMS"; +const char *enum_auto_big_o_n_lg_n = "NlgN"; +const char *lambda_big_o_n_lg_n = "f\\(N\\)"; + +// Add enum tests +ADD_COMPLEXITY_CASES(n_lg_n_test_name, big_o_n_lg_n_test_name, + rms_o_n_lg_n_test_name, enum_auto_big_o_n_lg_n, + /*family_index=*/6); + +// NOTE: auto big-o is wron.g +ADD_COMPLEXITY_CASES(n_lg_n_test_name, big_o_n_lg_n_test_name, + rms_o_n_lg_n_test_name, enum_auto_big_o_n_lg_n, + /*family_index=*/7); + +//// Add lambda tests +ADD_COMPLEXITY_CASES(n_lg_n_test_name, big_o_n_lg_n_test_name, + rms_o_n_lg_n_test_name, lambda_big_o_n_lg_n, + /*family_index=*/8); + +// ========================================================================= // +// -------- Testing formatting of Complexity with captured args ------------ // +// ========================================================================= // + +void BM_ComplexityCaptureArgs(benchmark::State &state, int n) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + benchmark::DoNotOptimize(state.iterations()); + double tmp = static_cast(state.iterations()); 
+ benchmark::DoNotOptimize(tmp); + for (benchmark::IterationCount i = 0; i < state.iterations(); ++i) { + benchmark::DoNotOptimize(state.iterations()); + tmp *= static_cast(state.iterations()); + benchmark::DoNotOptimize(tmp); + } + + state.SetIterationTime(static_cast(state.range(0)) * 42 * 1e-9); + } + state.SetComplexityN(n); +} + +BENCHMARK_CAPTURE(BM_ComplexityCaptureArgs, capture_test, 100) + ->UseManualTime() + ->Complexity(benchmark::oN) + ->Ranges({{1, 2}, {3, 4}}); + +const std::string complexity_capture_name = + "BM_ComplexityCaptureArgs/capture_test/manual_time"; + +ADD_COMPLEXITY_CASES(complexity_capture_name, complexity_capture_name + "_BigO", + complexity_capture_name + "_RMS", "N", + /*family_index=*/9); + +// ========================================================================= // +// --------------------------- TEST CASES END ------------------------------ // +// ========================================================================= // + +int main(int argc, char *argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/test/cxx03_test.cc b/third_party/benchmark/test/cxx03_test.cc new file mode 100644 index 0000000..9711c1b --- /dev/null +++ b/third_party/benchmark/test/cxx03_test.cc @@ -0,0 +1,62 @@ +#undef NDEBUG +#include +#include + +#include "benchmark/benchmark.h" + +#if __cplusplus >= 201103L +#error C++11 or greater detected. Should be C++03. +#endif + +#ifdef BENCHMARK_HAS_CXX11 +#error C++11 or greater detected by the library. BENCHMARK_HAS_CXX11 is defined. +#endif + +void BM_empty(benchmark::State& state) { + while (state.KeepRunning()) { + volatile benchmark::IterationCount x = state.iterations(); + ((void)x); + } +} +BENCHMARK(BM_empty); + +// The new C++11 interface for args/ranges requires initializer list support. +// Therefore we provide the old interface to support C++03. 
+void BM_old_arg_range_interface(benchmark::State& state) { + assert((state.range(0) == 1 && state.range(1) == 2) || + (state.range(0) == 5 && state.range(1) == 6)); + while (state.KeepRunning()) { + } +} +BENCHMARK(BM_old_arg_range_interface)->ArgPair(1, 2)->RangePair(5, 5, 6, 6); + +template +void BM_template2(benchmark::State& state) { + BM_empty(state); +} +BENCHMARK_TEMPLATE2(BM_template2, int, long); + +template +void BM_template1(benchmark::State& state) { + BM_empty(state); +} +BENCHMARK_TEMPLATE(BM_template1, long); +BENCHMARK_TEMPLATE1(BM_template1, int); + +template +struct BM_Fixture : public ::benchmark::Fixture {}; + +BENCHMARK_TEMPLATE_F(BM_Fixture, BM_template1, long)(benchmark::State& state) { + BM_empty(state); +} +BENCHMARK_TEMPLATE1_F(BM_Fixture, BM_template2, int)(benchmark::State& state) { + BM_empty(state); +} + +void BM_counters(benchmark::State& state) { + BM_empty(state); + state.counters["Foo"] = 2; +} +BENCHMARK(BM_counters); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/diagnostics_test.cc b/third_party/benchmark/test/diagnostics_test.cc new file mode 100644 index 0000000..7c68a98 --- /dev/null +++ b/third_party/benchmark/test/diagnostics_test.cc @@ -0,0 +1,91 @@ +// Testing: +// State::PauseTiming() +// State::ResumeTiming() +// Test that CHECK's within these function diagnose when they are called +// outside of the KeepRunning() loop. +// +// NOTE: Users should NOT include or use src/check.h. This is only done in +// order to test library internals. 
+ +#include +#include + +#include "../src/check.h" +#include "benchmark/benchmark.h" + +#if defined(__GNUC__) && !defined(__EXCEPTIONS) +#define TEST_HAS_NO_EXCEPTIONS +#endif + +void TestHandler() { +#ifndef TEST_HAS_NO_EXCEPTIONS + throw std::logic_error(""); +#else + std::abort(); +#endif +} + +void try_invalid_pause_resume(benchmark::State& state) { +#if !defined(TEST_BENCHMARK_LIBRARY_HAS_NO_ASSERTIONS) && \ + !defined(TEST_HAS_NO_EXCEPTIONS) + try { + state.PauseTiming(); + std::abort(); + } catch (std::logic_error const&) { + } + try { + state.ResumeTiming(); + std::abort(); + } catch (std::logic_error const&) { + } +#else + (void)state; // avoid unused warning +#endif +} + +void BM_diagnostic_test(benchmark::State& state) { + static bool called_once = false; + + if (called_once == false) try_invalid_pause_resume(state); + + for (auto _ : state) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + + if (called_once == false) try_invalid_pause_resume(state); + + called_once = true; +} +BENCHMARK(BM_diagnostic_test); + +void BM_diagnostic_test_keep_running(benchmark::State& state) { + static bool called_once = false; + + if (called_once == false) try_invalid_pause_resume(state); + + while (state.KeepRunning()) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + + if (called_once == false) try_invalid_pause_resume(state); + + called_once = true; +} +BENCHMARK(BM_diagnostic_test_keep_running); + +int main(int argc, char* argv[]) { +#ifdef NDEBUG + // This test is exercising functionality for debug builds, which are not + // available in release builds. Skip the test if we are in that environment + // to avoid a test failure. 
+ std::cout << "Diagnostic test disabled in release build" << std::endl; + (void)argc; + (void)argv; +#else + benchmark::internal::GetAbortHandler() = &TestHandler; + benchmark::Initialize(&argc, argv); + benchmark::RunSpecifiedBenchmarks(); +#endif +} diff --git a/third_party/benchmark/test/display_aggregates_only_test.cc b/third_party/benchmark/test/display_aggregates_only_test.cc new file mode 100644 index 0000000..6ad65e7 --- /dev/null +++ b/third_party/benchmark/test/display_aggregates_only_test.cc @@ -0,0 +1,45 @@ + +#undef NDEBUG +#include +#include + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// Ok this test is super ugly. We want to check what happens with the file +// reporter in the presence of DisplayAggregatesOnly(). +// We do not care about console output, the normal tests check that already. + +void BM_SummaryRepeat(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_SummaryRepeat)->Repetitions(3)->DisplayAggregatesOnly(); + +int main(int argc, char* argv[]) { + const std::string output = GetFileReporterOutput(argc, argv); + + if (SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3") != 7 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3\"") != 3 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_mean\"") != 1 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_median\"") != + 1 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_stddev\"") != + 1 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_cv\"") != 1) { + std::cout << "Precondition mismatch. 
Expected to only find 8 " + "occurrences of \"BM_SummaryRepeat/repeats:3\" substring:\n" + "\"name\": \"BM_SummaryRepeat/repeats:3\", " + "\"name\": \"BM_SummaryRepeat/repeats:3\", " + "\"name\": \"BM_SummaryRepeat/repeats:3\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_mean\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_median\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_stddev\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_cv\"\nThe entire " + "output:\n"; + std::cout << output; + return 1; + } + + return 0; +} diff --git a/third_party/benchmark/test/donotoptimize_assembly_test.cc b/third_party/benchmark/test/donotoptimize_assembly_test.cc new file mode 100644 index 0000000..dc286f5 --- /dev/null +++ b/third_party/benchmark/test/donotoptimize_assembly_test.cc @@ -0,0 +1,201 @@ +#include + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wreturn-type" +#endif +BENCHMARK_DISABLE_DEPRECATED_WARNING + +extern "C" { + +extern int ExternInt; +extern int ExternInt2; +extern int ExternInt3; +extern int BigArray[2049]; + +const int ConstBigArray[2049]{}; + +inline int Add42(int x) { return x + 42; } + +struct NotTriviallyCopyable { + NotTriviallyCopyable(); + explicit NotTriviallyCopyable(int x) : value(x) {} + NotTriviallyCopyable(NotTriviallyCopyable const &); + int value; +}; + +struct Large { + int value; + int data[2]; +}; + +struct ExtraLarge { + int arr[2049]; +}; +} + +extern ExtraLarge ExtraLargeObj; +const ExtraLarge ConstExtraLargeObj{}; + +// CHECK-LABEL: test_with_rvalue: +extern "C" void test_with_rvalue() { + benchmark::DoNotOptimize(Add42(0)); + // CHECK: movl $42, %eax + // CHECK: ret +} + +// CHECK-LABEL: test_with_large_rvalue: +extern "C" void test_with_large_rvalue() { + benchmark::DoNotOptimize(Large{ExternInt, {ExternInt, ExternInt}}); + // CHECK: ExternInt(%rip) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG:[a-z]+]] + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG]]) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG]]) + // CHECK: ret +} + +// CHECK-LABEL: 
test_with_non_trivial_rvalue: +extern "C" void test_with_non_trivial_rvalue() { + benchmark::DoNotOptimize(NotTriviallyCopyable(ExternInt)); + // CHECK: mov{{l|q}} ExternInt(%rip) + // CHECK: ret +} + +// CHECK-LABEL: test_with_lvalue: +extern "C" void test_with_lvalue() { + int x = 101; + benchmark::DoNotOptimize(x); + // CHECK-GNU: movl $101, %eax + // CHECK-CLANG: movl $101, -{{[0-9]+}}(%[[REG:[a-z]+]]) + // CHECK: ret +} + +// CHECK-LABEL: test_with_large_lvalue: +extern "C" void test_with_large_lvalue() { + Large L{ExternInt, {ExternInt, ExternInt}}; + benchmark::DoNotOptimize(L); + // CHECK: ExternInt(%rip) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG:[a-z]+]]) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG]]) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG]]) + // CHECK: ret +} + +// CHECK-LABEL: test_with_extra_large_lvalue_with_op: +extern "C" void test_with_extra_large_lvalue_with_op() { + ExtraLargeObj.arr[16] = 42; + benchmark::DoNotOptimize(ExtraLargeObj); + // CHECK: movl $42, ExtraLargeObj+64(%rip) + // CHECK: ret +} + +// CHECK-LABEL: test_with_big_array_with_op +extern "C" void test_with_big_array_with_op() { + BigArray[16] = 42; + benchmark::DoNotOptimize(BigArray); + // CHECK: movl $42, BigArray+64(%rip) + // CHECK: ret +} + +// CHECK-LABEL: test_with_non_trivial_lvalue: +extern "C" void test_with_non_trivial_lvalue() { + NotTriviallyCopyable NTC(ExternInt); + benchmark::DoNotOptimize(NTC); + // CHECK: ExternInt(%rip) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG:[a-z]+]]) + // CHECK: ret +} + +// CHECK-LABEL: test_with_const_lvalue: +extern "C" void test_with_const_lvalue() { + const int x = 123; + benchmark::DoNotOptimize(x); + // CHECK: movl $123, %eax + // CHECK: ret +} + +// CHECK-LABEL: test_with_large_const_lvalue: +extern "C" void test_with_large_const_lvalue() { + const Large L{ExternInt, {ExternInt, ExternInt}}; + benchmark::DoNotOptimize(L); + // CHECK: ExternInt(%rip) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG:[a-z]+]]) + // CHECK: movl %eax, 
-{{[0-9]+}}(%[[REG]]) + // CHECK: movl %eax, -{{[0-9]+}}(%[[REG]]) + // CHECK: ret +} + +// CHECK-LABEL: test_with_const_extra_large_obj: +extern "C" void test_with_const_extra_large_obj() { + benchmark::DoNotOptimize(ConstExtraLargeObj); + // CHECK: ret +} + +// CHECK-LABEL: test_with_const_big_array +extern "C" void test_with_const_big_array() { + benchmark::DoNotOptimize(ConstBigArray); + // CHECK: ret +} + +// CHECK-LABEL: test_with_non_trivial_const_lvalue: +extern "C" void test_with_non_trivial_const_lvalue() { + const NotTriviallyCopyable Obj(ExternInt); + benchmark::DoNotOptimize(Obj); + // CHECK: mov{{q|l}} ExternInt(%rip) + // CHECK: ret +} + +// CHECK-LABEL: test_div_by_two: +extern "C" int test_div_by_two(int input) { + int divisor = 2; + benchmark::DoNotOptimize(divisor); + return input / divisor; + // CHECK: movl $2, [[DEST:.*]] + // CHECK: idivl [[DEST]] + // CHECK: ret +} + +// CHECK-LABEL: test_inc_integer: +extern "C" int test_inc_integer() { + int x = 0; + for (int i = 0; i < 5; ++i) benchmark::DoNotOptimize(++x); + // CHECK: movl $1, [[DEST:.*]] + // CHECK: {{(addl \$1,|incl)}} [[DEST]] + // CHECK: {{(addl \$1,|incl)}} [[DEST]] + // CHECK: {{(addl \$1,|incl)}} [[DEST]] + // CHECK: {{(addl \$1,|incl)}} [[DEST]] + // CHECK-CLANG: movl [[DEST]], %eax + // CHECK: ret + return x; +} + +// CHECK-LABEL: test_pointer_rvalue +extern "C" void test_pointer_rvalue() { + // CHECK: movl $42, [[DEST:.*]] + // CHECK: leaq [[DEST]], %rax + // CHECK-CLANG: movq %rax, -{{[0-9]+}}(%[[REG:[a-z]+]]) + // CHECK: ret + int x = 42; + benchmark::DoNotOptimize(&x); +} + +// CHECK-LABEL: test_pointer_const_lvalue: +extern "C" void test_pointer_const_lvalue() { + // CHECK: movl $42, [[DEST:.*]] + // CHECK: leaq [[DEST]], %rax + // CHECK-CLANG: movq %rax, -{{[0-9]+}}(%[[REG:[a-z]+]]) + // CHECK: ret + int x = 42; + int *const xp = &x; + benchmark::DoNotOptimize(xp); +} + +// CHECK-LABEL: test_pointer_lvalue: +extern "C" void test_pointer_lvalue() { + // CHECK: movl $42, 
[[DEST:.*]] + // CHECK: leaq [[DEST]], %rax + // CHECK-CLANG: movq %rax, -{{[0-9]+}}(%[[REG:[a-z+]+]]) + // CHECK: ret + int x = 42; + int *xp = &x; + benchmark::DoNotOptimize(xp); +} diff --git a/third_party/benchmark/test/donotoptimize_test.cc b/third_party/benchmark/test/donotoptimize_test.cc new file mode 100644 index 0000000..04ec938 --- /dev/null +++ b/third_party/benchmark/test/donotoptimize_test.cc @@ -0,0 +1,69 @@ +#include + +#include "benchmark/benchmark.h" + +namespace { +#if defined(__GNUC__) +std::int64_t double_up(const std::int64_t x) __attribute__((const)); +#endif +std::int64_t double_up(const std::int64_t x) { return x * 2; } +} // namespace + +// Using DoNotOptimize on types like BitRef seem to cause a lot of problems +// with the inline assembly on both GCC and Clang. +struct BitRef { + int index; + unsigned char& byte; + + public: + static BitRef Make() { + static unsigned char arr[2] = {}; + BitRef b(1, arr[0]); + return b; + } + + private: + BitRef(int i, unsigned char& b) : index(i), byte(b) {} +}; + +int main(int, char*[]) { + // this test verifies compilation of DoNotOptimize() for some types + + char buffer1[1] = ""; + benchmark::DoNotOptimize(buffer1); + + char buffer2[2] = ""; + benchmark::DoNotOptimize(buffer2); + + char buffer3[3] = ""; + benchmark::DoNotOptimize(buffer3); + + char buffer8[8] = ""; + benchmark::DoNotOptimize(buffer8); + + char buffer20[20] = ""; + benchmark::DoNotOptimize(buffer20); + + char buffer1024[1024] = ""; + benchmark::DoNotOptimize(buffer1024); + char* bptr = &buffer1024[0]; + benchmark::DoNotOptimize(bptr); + + int x = 123; + benchmark::DoNotOptimize(x); + int* xp = &x; + benchmark::DoNotOptimize(xp); + benchmark::DoNotOptimize(x += 42); + + std::int64_t y = double_up(x); + benchmark::DoNotOptimize(y); + + // These tests are to e + BitRef lval = BitRef::Make(); + benchmark::DoNotOptimize(lval); + +#ifdef BENCHMARK_HAS_CXX11 + // Check that accept rvalue. 
+ benchmark::DoNotOptimize(BitRef::Make()); +#endif +} diff --git a/third_party/benchmark/test/filter_test.cc b/third_party/benchmark/test/filter_test.cc new file mode 100644 index 0000000..4c8b8ea --- /dev/null +++ b/third_party/benchmark/test/filter_test.cc @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +namespace { + +class TestReporter : public benchmark::ConsoleReporter { + public: + bool ReportContext(const Context& context) override { + return ConsoleReporter::ReportContext(context); + }; + + void ReportRuns(const std::vector& report) override { + ++count_; + max_family_index_ = std::max(max_family_index_, report[0].family_index); + ConsoleReporter::ReportRuns(report); + }; + + TestReporter() : count_(0), max_family_index_(0) {} + + ~TestReporter() override {} + + int GetCount() const { return count_; } + + int64_t GetMaxFamilyIndex() const { return max_family_index_; } + + private: + mutable int count_; + mutable int64_t max_family_index_; +}; + +} // end namespace + +static void NoPrefix(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(NoPrefix); + +static void BM_Foo(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_Foo); + +static void BM_Bar(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_Bar); + +static void BM_FooBar(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_FooBar); + +static void BM_FooBa(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_FooBa); + +int main(int argc, char** argv) { + bool list_only = false; + for (int i = 0; i < argc; ++i) + list_only |= std::string(argv[i]).find("--benchmark_list_tests") != + std::string::npos; + + benchmark::Initialize(&argc, argv); + + TestReporter test_reporter; + const int64_t returned_count = + static_cast(benchmark::RunSpecifiedBenchmarks(&test_reporter)); + + if (argc == 2) { + // Make sure 
we ran all of the tests + std::stringstream ss(argv[1]); + int64_t expected_return; + ss >> expected_return; + + if (returned_count != expected_return) { + std::cerr << "ERROR: Expected " << expected_return + << " tests to match the filter but returned_count = " + << returned_count << std::endl; + return -1; + } + + const int64_t expected_reports = list_only ? 0 : expected_return; + const int64_t reports_count = test_reporter.GetCount(); + if (reports_count != expected_reports) { + std::cerr << "ERROR: Expected " << expected_reports + << " tests to be run but reported_count = " << reports_count + << std::endl; + return -1; + } + + const int64_t max_family_index = test_reporter.GetMaxFamilyIndex(); + const int64_t num_families = reports_count == 0 ? 0 : 1 + max_family_index; + if (num_families != expected_reports) { + std::cerr << "ERROR: Expected " << expected_reports + << " test families to be run but num_families = " + << num_families << std::endl; + return -1; + } + } + + return 0; +} diff --git a/third_party/benchmark/test/fixture_test.cc b/third_party/benchmark/test/fixture_test.cc new file mode 100644 index 0000000..d1093eb --- /dev/null +++ b/third_party/benchmark/test/fixture_test.cc @@ -0,0 +1,51 @@ + +#include +#include + +#include "benchmark/benchmark.h" + +#define FIXTURE_BECHMARK_NAME MyFixture + +class FIXTURE_BECHMARK_NAME : public ::benchmark::Fixture { + public: + void SetUp(const ::benchmark::State& state) override { + if (state.thread_index() == 0) { + assert(data.get() == nullptr); + data.reset(new int(42)); + } + } + + void TearDown(const ::benchmark::State& state) override { + if (state.thread_index() == 0) { + assert(data.get() != nullptr); + data.reset(); + } + } + + ~FIXTURE_BECHMARK_NAME() override { assert(data == nullptr); } + + std::unique_ptr data; +}; + +BENCHMARK_F(FIXTURE_BECHMARK_NAME, Foo)(benchmark::State& st) { + assert(data.get() != nullptr); + assert(*data == 42); + for (auto _ : st) { + } +} + 
+BENCHMARK_DEFINE_F(FIXTURE_BECHMARK_NAME, Bar)(benchmark::State& st) { + if (st.thread_index() == 0) { + assert(data.get() != nullptr); + assert(*data == 42); + } + for (auto _ : st) { + assert(data.get() != nullptr); + assert(*data == 42); + } + st.SetItemsProcessed(st.range(0)); +} +BENCHMARK_REGISTER_F(FIXTURE_BECHMARK_NAME, Bar)->Arg(42); +BENCHMARK_REGISTER_F(FIXTURE_BECHMARK_NAME, Bar)->Arg(42)->ThreadPerCpu(); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/internal_threading_test.cc b/third_party/benchmark/test/internal_threading_test.cc new file mode 100644 index 0000000..62b5b95 --- /dev/null +++ b/third_party/benchmark/test/internal_threading_test.cc @@ -0,0 +1,185 @@ + +#undef NDEBUG + +#include +#include + +#include "../src/timers.h" +#include "benchmark/benchmark.h" +#include "output_test.h" + +static const std::chrono::duration time_frame(50); +static const double time_frame_in_sec( + std::chrono::duration_cast>>( + time_frame) + .count()); + +void MyBusySpinwait() { + const auto start = benchmark::ChronoClockNow(); + + while (true) { + const auto now = benchmark::ChronoClockNow(); + const auto elapsed = now - start; + + if (std::chrono::duration(elapsed) >= + time_frame) + return; + } +} + +// ========================================================================= // +// --------------------------- TEST CASES BEGIN ---------------------------- // +// ========================================================================= // + +// ========================================================================= // +// BM_MainThread + +void BM_MainThread(benchmark::State& state) { + for (auto _ : state) { + MyBusySpinwait(); + state.SetIterationTime(time_frame_in_sec); + } + state.counters["invtime"] = + benchmark::Counter{1, benchmark::Counter::kIsRate}; +} + +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(1); +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(1)->UseRealTime(); 
+BENCHMARK(BM_MainThread)->Iterations(1)->Threads(1)->UseManualTime(); +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(1)->MeasureProcessCPUTime(); +BENCHMARK(BM_MainThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime() + ->UseRealTime(); +BENCHMARK(BM_MainThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime() + ->UseManualTime(); + +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(2); +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(2)->UseRealTime(); +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(2)->UseManualTime(); +BENCHMARK(BM_MainThread)->Iterations(1)->Threads(2)->MeasureProcessCPUTime(); +BENCHMARK(BM_MainThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime() + ->UseRealTime(); +BENCHMARK(BM_MainThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime() + ->UseManualTime(); + +// ========================================================================= // +// BM_WorkerThread + +void BM_WorkerThread(benchmark::State& state) { + for (auto _ : state) { + std::thread Worker(&MyBusySpinwait); + Worker.join(); + state.SetIterationTime(time_frame_in_sec); + } + state.counters["invtime"] = + benchmark::Counter{1, benchmark::Counter::kIsRate}; +} + +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(1); +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(1)->UseRealTime(); +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(1)->UseManualTime(); +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(1)->MeasureProcessCPUTime(); +BENCHMARK(BM_WorkerThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime() + ->UseRealTime(); +BENCHMARK(BM_WorkerThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime() + ->UseManualTime(); + +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(2); +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(2)->UseRealTime(); +BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(2)->UseManualTime(); 
+BENCHMARK(BM_WorkerThread)->Iterations(1)->Threads(2)->MeasureProcessCPUTime(); +BENCHMARK(BM_WorkerThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime() + ->UseRealTime(); +BENCHMARK(BM_WorkerThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime() + ->UseManualTime(); + +// ========================================================================= // +// BM_MainThreadAndWorkerThread + +void BM_MainThreadAndWorkerThread(benchmark::State& state) { + for (auto _ : state) { + std::thread Worker(&MyBusySpinwait); + MyBusySpinwait(); + Worker.join(); + state.SetIterationTime(time_frame_in_sec); + } + state.counters["invtime"] = + benchmark::Counter{1, benchmark::Counter::kIsRate}; +} + +BENCHMARK(BM_MainThreadAndWorkerThread)->Iterations(1)->Threads(1); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(1) + ->UseRealTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(1) + ->UseManualTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime() + ->UseRealTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(1) + ->MeasureProcessCPUTime() + ->UseManualTime(); + +BENCHMARK(BM_MainThreadAndWorkerThread)->Iterations(1)->Threads(2); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(2) + ->UseRealTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(2) + ->UseManualTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime() + ->UseRealTime(); +BENCHMARK(BM_MainThreadAndWorkerThread) + ->Iterations(1) + ->Threads(2) + ->MeasureProcessCPUTime() + ->UseManualTime(); + +// 
========================================================================= // +// ---------------------------- TEST CASES END ----------------------------- // +// ========================================================================= // + +int main(int argc, char* argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/test/link_main_test.cc b/third_party/benchmark/test/link_main_test.cc new file mode 100644 index 0000000..131937e --- /dev/null +++ b/third_party/benchmark/test/link_main_test.cc @@ -0,0 +1,9 @@ +#include "benchmark/benchmark.h" + +void BM_empty(benchmark::State& state) { + for (auto _ : state) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } +} +BENCHMARK(BM_empty); diff --git a/third_party/benchmark/test/map_test.cc b/third_party/benchmark/test/map_test.cc new file mode 100644 index 0000000..0fdba7c --- /dev/null +++ b/third_party/benchmark/test/map_test.cc @@ -0,0 +1,59 @@ +#include +#include + +#include "benchmark/benchmark.h" + +namespace { + +std::map ConstructRandomMap(int size) { + std::map m; + for (int i = 0; i < size; ++i) { + m.insert(std::make_pair(std::rand() % size, std::rand() % size)); + } + return m; +} + +} // namespace + +// Basic version. +static void BM_MapLookup(benchmark::State& state) { + const int size = static_cast(state.range(0)); + std::map m; + for (auto _ : state) { + state.PauseTiming(); + m = ConstructRandomMap(size); + state.ResumeTiming(); + for (int i = 0; i < size; ++i) { + auto it = m.find(std::rand() % size); + benchmark::DoNotOptimize(it); + } + } + state.SetItemsProcessed(state.iterations() * size); +} +BENCHMARK(BM_MapLookup)->Range(1 << 3, 1 << 12); + +// Using fixtures. 
+class MapFixture : public ::benchmark::Fixture { + public: + void SetUp(const ::benchmark::State& st) override { + m = ConstructRandomMap(static_cast(st.range(0))); + } + + void TearDown(const ::benchmark::State&) override { m.clear(); } + + std::map m; +}; + +BENCHMARK_DEFINE_F(MapFixture, Lookup)(benchmark::State& state) { + const int size = static_cast(state.range(0)); + for (auto _ : state) { + for (int i = 0; i < size; ++i) { + auto it = m.find(std::rand() % size); + benchmark::DoNotOptimize(it); + } + } + state.SetItemsProcessed(state.iterations() * size); +} +BENCHMARK_REGISTER_F(MapFixture, Lookup)->Range(1 << 3, 1 << 12); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/memory_manager_test.cc b/third_party/benchmark/test/memory_manager_test.cc new file mode 100644 index 0000000..4df674d --- /dev/null +++ b/third_party/benchmark/test/memory_manager_test.cc @@ -0,0 +1,47 @@ +#include + +#include "../src/check.h" +#include "benchmark/benchmark.h" +#include "output_test.h" + +class TestMemoryManager : public benchmark::MemoryManager { + void Start() override {} + void Stop(Result& result) override { + result.num_allocs = 42; + result.max_bytes_used = 42000; + } +}; + +void BM_empty(benchmark::State& state) { + for (auto _ : state) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } +} +BENCHMARK(BM_empty); + +ADD_CASES(TC_ConsoleOut, {{"^BM_empty %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_empty\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_empty\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + 
{"\"allocs_per_iter\": %float,$", MR_Next}, + {"\"max_bytes_used\": 42000$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_empty\",%csv_report$"}}); + +int main(int argc, char* argv[]) { + std::unique_ptr mm(new TestMemoryManager()); + + benchmark::RegisterMemoryManager(mm.get()); + RunOutputTests(argc, argv); + benchmark::RegisterMemoryManager(nullptr); +} diff --git a/third_party/benchmark/test/min_time_parse_gtest.cc b/third_party/benchmark/test/min_time_parse_gtest.cc new file mode 100644 index 0000000..e2bdf67 --- /dev/null +++ b/third_party/benchmark/test/min_time_parse_gtest.cc @@ -0,0 +1,30 @@ +#include "../src/benchmark_runner.h" +#include "gtest/gtest.h" + +namespace { + +TEST(ParseMinTimeTest, InvalidInput) { +#if GTEST_HAS_DEATH_TEST + // Tests only runnable in debug mode (when BM_CHECK is enabled). +#ifndef NDEBUG +#ifndef TEST_BENCHMARK_LIBRARY_HAS_NO_ASSERTIONS + ASSERT_DEATH_IF_SUPPORTED( + { benchmark::internal::ParseBenchMinTime("abc"); }, + "Malformed seconds value passed to --benchmark_min_time: `abc`"); + + ASSERT_DEATH_IF_SUPPORTED( + { benchmark::internal::ParseBenchMinTime("123ms"); }, + "Malformed seconds value passed to --benchmark_min_time: `123ms`"); + + ASSERT_DEATH_IF_SUPPORTED( + { benchmark::internal::ParseBenchMinTime("1z"); }, + "Malformed seconds value passed to --benchmark_min_time: `1z`"); + + ASSERT_DEATH_IF_SUPPORTED( + { benchmark::internal::ParseBenchMinTime("1hs"); }, + "Malformed seconds value passed to --benchmark_min_time: `1hs`"); +#endif +#endif +#endif +} +} // namespace diff --git a/third_party/benchmark/test/multiple_ranges_test.cc b/third_party/benchmark/test/multiple_ranges_test.cc new file mode 100644 index 0000000..5300a96 --- /dev/null +++ b/third_party/benchmark/test/multiple_ranges_test.cc @@ -0,0 +1,96 @@ +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +class MultipleRangesFixture : public ::benchmark::Fixture { + public: + MultipleRangesFixture() + : 
expectedValues({{1, 3, 5}, + {1, 3, 8}, + {1, 3, 15}, + {2, 3, 5}, + {2, 3, 8}, + {2, 3, 15}, + {1, 4, 5}, + {1, 4, 8}, + {1, 4, 15}, + {2, 4, 5}, + {2, 4, 8}, + {2, 4, 15}, + {1, 7, 5}, + {1, 7, 8}, + {1, 7, 15}, + {2, 7, 5}, + {2, 7, 8}, + {2, 7, 15}, + {7, 6, 3}}) {} + + void SetUp(const ::benchmark::State& state) override { + std::vector ranges = {state.range(0), state.range(1), + state.range(2)}; + + assert(expectedValues.find(ranges) != expectedValues.end()); + + actualValues.insert(ranges); + } + + // NOTE: This is not TearDown as we want to check after _all_ runs are + // complete. + ~MultipleRangesFixture() override { + if (actualValues != expectedValues) { + std::cout << "EXPECTED\n"; + for (const auto& v : expectedValues) { + std::cout << "{"; + for (int64_t iv : v) { + std::cout << iv << ", "; + } + std::cout << "}\n"; + } + std::cout << "ACTUAL\n"; + for (const auto& v : actualValues) { + std::cout << "{"; + for (int64_t iv : v) { + std::cout << iv << ", "; + } + std::cout << "}\n"; + } + } + } + + std::set> expectedValues; + std::set> actualValues; +}; + +BENCHMARK_DEFINE_F(MultipleRangesFixture, Empty)(benchmark::State& state) { + for (auto _ : state) { + int64_t product = state.range(0) * state.range(1) * state.range(2); + for (int64_t x = 0; x < product; x++) { + benchmark::DoNotOptimize(x); + } + } +} + +BENCHMARK_REGISTER_F(MultipleRangesFixture, Empty) + ->RangeMultiplier(2) + ->Ranges({{1, 2}, {3, 7}, {5, 15}}) + ->Args({7, 6, 3}); + +void BM_CheckDefaultArgument(benchmark::State& state) { + // Test that the 'range()' without an argument is the same as 'range(0)'. 
+ assert(state.range() == state.range(0)); + assert(state.range() != state.range(1)); + for (auto _ : state) { + } +} +BENCHMARK(BM_CheckDefaultArgument)->Ranges({{1, 5}, {6, 10}}); + +static void BM_MultipleRanges(benchmark::State& st) { + for (auto _ : st) { + } +} +BENCHMARK(BM_MultipleRanges)->Ranges({{5, 5}, {6, 6}}); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/options_test.cc b/third_party/benchmark/test/options_test.cc new file mode 100644 index 0000000..a1b209f --- /dev/null +++ b/third_party/benchmark/test/options_test.cc @@ -0,0 +1,77 @@ +#include +#include + +#include "benchmark/benchmark.h" + +#if defined(NDEBUG) +#undef NDEBUG +#endif +#include + +void BM_basic(benchmark::State& state) { + for (auto _ : state) { + } +} + +void BM_basic_slow(benchmark::State& state) { + std::chrono::milliseconds sleep_duration(state.range(0)); + for (auto _ : state) { + std::this_thread::sleep_for( + std::chrono::duration_cast(sleep_duration)); + } +} + +BENCHMARK(BM_basic); +BENCHMARK(BM_basic)->Arg(42); +BENCHMARK(BM_basic_slow)->Arg(10)->Unit(benchmark::kNanosecond); +BENCHMARK(BM_basic_slow)->Arg(100)->Unit(benchmark::kMicrosecond); +BENCHMARK(BM_basic_slow)->Arg(1000)->Unit(benchmark::kMillisecond); +BENCHMARK(BM_basic_slow)->Arg(1000)->Unit(benchmark::kSecond); +BENCHMARK(BM_basic)->Range(1, 8); +BENCHMARK(BM_basic)->RangeMultiplier(2)->Range(1, 8); +BENCHMARK(BM_basic)->DenseRange(10, 15); +BENCHMARK(BM_basic)->Args({42, 42}); +BENCHMARK(BM_basic)->Ranges({{64, 512}, {64, 512}}); +BENCHMARK(BM_basic)->MinTime(0.7); +BENCHMARK(BM_basic)->MinWarmUpTime(0.8); +BENCHMARK(BM_basic)->MinTime(0.1)->MinWarmUpTime(0.2); +BENCHMARK(BM_basic)->UseRealTime(); +BENCHMARK(BM_basic)->ThreadRange(2, 4); +BENCHMARK(BM_basic)->ThreadPerCpu(); +BENCHMARK(BM_basic)->Repetitions(3); +BENCHMARK(BM_basic) + ->RangeMultiplier(std::numeric_limits::max()) + ->Range(std::numeric_limits::min(), + std::numeric_limits::max()); + +// Negative ranges 
+BENCHMARK(BM_basic)->Range(-64, -1); +BENCHMARK(BM_basic)->RangeMultiplier(4)->Range(-8, 8); +BENCHMARK(BM_basic)->DenseRange(-2, 2, 1); +BENCHMARK(BM_basic)->Ranges({{-64, 1}, {-8, -1}}); + +void CustomArgs(benchmark::internal::Benchmark* b) { + for (int i = 0; i < 10; ++i) { + b->Arg(i); + } +} + +BENCHMARK(BM_basic)->Apply(CustomArgs); + +void BM_explicit_iteration_count(benchmark::State& state) { + // Test that benchmarks specified with an explicit iteration count are + // only run once. + static bool invoked_before = false; + assert(!invoked_before); + invoked_before = true; + + // Test that the requested iteration count is respected. + assert(state.max_iterations == 42); + for (auto _ : state) { + } + assert(state.iterations() == state.max_iterations); + assert(state.iterations() == 42); +} +BENCHMARK(BM_explicit_iteration_count)->Iterations(42); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/output_test.h b/third_party/benchmark/test/output_test.h new file mode 100644 index 0000000..c08fe1d --- /dev/null +++ b/third_party/benchmark/test/output_test.h @@ -0,0 +1,211 @@ +#ifndef TEST_OUTPUT_TEST_H +#define TEST_OUTPUT_TEST_H + +#undef NDEBUG +#include +#include +#include +#include +#include +#include +#include + +#include "../src/re.h" +#include "benchmark/benchmark.h" + +#define CONCAT2(x, y) x##y +#define CONCAT(x, y) CONCAT2(x, y) + +#define ADD_CASES(...) int CONCAT(dummy, __LINE__) = ::AddCases(__VA_ARGS__) + +#define SET_SUBSTITUTIONS(...) \ + int CONCAT(dummy, __LINE__) = ::SetSubstitutions(__VA_ARGS__) + +enum MatchRules { + MR_Default, // Skip non-matching lines until a match is found. + MR_Next, // Match must occur on the next line. 
+ MR_Not // No line between the current position and the next match matches + // the regex +}; + +struct TestCase { + TestCase(std::string re, int rule = MR_Default); + + std::string regex_str; + int match_rule; + std::string substituted_regex; + std::shared_ptr regex; +}; + +enum TestCaseID { + TC_ConsoleOut, + TC_ConsoleErr, + TC_JSONOut, + TC_JSONErr, + TC_CSVOut, + TC_CSVErr, + + TC_NumID // PRIVATE +}; + +// Add a list of test cases to be run against the output specified by +// 'ID' +int AddCases(TestCaseID ID, std::initializer_list il); + +// Add or set a list of substitutions to be performed on constructed regex's +// See 'output_test_helper.cc' for a list of default substitutions. +int SetSubstitutions( + std::initializer_list> il); + +// Run all output tests. +void RunOutputTests(int argc, char* argv[]); + +// Count the number of 'pat' substrings in the 'haystack' string. +int SubstrCnt(const std::string& haystack, const std::string& pat); + +// Run registered benchmarks with file reporter enabled, and return the content +// outputted by the file reporter. +std::string GetFileReporterOutput(int argc, char* argv[]); + +// ========================================================================= // +// ------------------------- Results checking ------------------------------ // +// ========================================================================= // + +// Call this macro to register a benchmark for checking its results. This +// should be all that's needed. It subscribes a function to check the (CSV) +// results of a benchmark. This is done only after verifying that the output +// strings are really as expected. +// bm_name_pattern: a name or a regex pattern which will be matched against +// all the benchmark names. 
Matching benchmarks +// will be the subject of a call to checker_function +// checker_function: should be of type ResultsCheckFn (see below) +#define CHECK_BENCHMARK_RESULTS(bm_name_pattern, checker_function) \ + size_t CONCAT(dummy, __LINE__) = AddChecker(bm_name_pattern, checker_function) + +struct Results; +typedef std::function ResultsCheckFn; + +size_t AddChecker(const std::string& bm_name_pattern, const ResultsCheckFn& fn); + +// Class holding the results of a benchmark. +// It is passed in calls to checker functions. +struct Results { + // the benchmark name + std::string name; + // the benchmark fields + std::map values; + + Results(const std::string& n) : name(n) {} + + int NumThreads() const; + + double NumIterations() const; + + typedef enum { kCpuTime, kRealTime } BenchmarkTime; + + // get cpu_time or real_time in seconds + double GetTime(BenchmarkTime which) const; + + // get the real_time duration of the benchmark in seconds. + // it is better to use fuzzy float checks for this, as the float + // ASCII formatting is lossy. + double DurationRealTime() const { + return NumIterations() * GetTime(kRealTime); + } + // get the cpu_time duration of the benchmark in seconds + double DurationCPUTime() const { return NumIterations() * GetTime(kCpuTime); } + + // get the string for a result by name, or nullptr if the name + // is not found + const std::string* Get(const std::string& entry_name) const { + auto it = values.find(entry_name); + if (it == values.end()) return nullptr; + return &it->second; + } + + // get a result by name, parsed as a specific type. + // NOTE: for counters, use GetCounterAs instead. + template + T GetAs(const std::string& entry_name) const; + + // counters are written as doubles, so they have to be read first + // as a double, and only then converted to the asked type. 
+ template + T GetCounterAs(const std::string& entry_name) const { + double dval = GetAs(entry_name); + T tval = static_cast(dval); + return tval; + } +}; + +template +T Results::GetAs(const std::string& entry_name) const { + auto* sv = Get(entry_name); + BM_CHECK(sv != nullptr && !sv->empty()); + std::stringstream ss; + ss << *sv; + T out; + ss >> out; + BM_CHECK(!ss.fail()); + return out; +} + +//---------------------------------- +// Macros to help in result checking. Do not use them with arguments causing +// side-effects. + +// clang-format off + +#define CHECK_RESULT_VALUE_IMPL(entry, getfn, var_type, var_name, relationship, value) \ + CONCAT(BM_CHECK_, relationship) \ + (entry.getfn< var_type >(var_name), (value)) << "\n" \ + << __FILE__ << ":" << __LINE__ << ": " << (entry).name << ":\n" \ + << __FILE__ << ":" << __LINE__ << ": " \ + << "expected (" << #var_type << ")" << (var_name) \ + << "=" << (entry).getfn< var_type >(var_name) \ + << " to be " #relationship " to " << (value) << "\n" + +// check with tolerance. eps_factor is the tolerance window, which is +// interpreted relative to value (eg, 0.1 means 10% of value). +#define CHECK_FLOAT_RESULT_VALUE_IMPL(entry, getfn, var_type, var_name, relationship, value, eps_factor) \ + CONCAT(BM_CHECK_FLOAT_, relationship) \ + (entry.getfn< var_type >(var_name), (value), (eps_factor) * (value)) << "\n" \ + << __FILE__ << ":" << __LINE__ << ": " << (entry).name << ":\n" \ + << __FILE__ << ":" << __LINE__ << ": " \ + << "expected (" << #var_type << ")" << (var_name) \ + << "=" << (entry).getfn< var_type >(var_name) \ + << " to be " #relationship " to " << (value) << "\n" \ + << __FILE__ << ":" << __LINE__ << ": " \ + << "with tolerance of " << (eps_factor) * (value) \ + << " (" << (eps_factor)*100. << "%), " \ + << "but delta was " << ((entry).getfn< var_type >(var_name) - (value)) \ + << " (" << (((entry).getfn< var_type >(var_name) - (value)) \ + / \ + ((value) > 1.e-5 || value < -1.e-5 ? value : 1.e-5)*100.) 
\ + << "%)" + +#define CHECK_RESULT_VALUE(entry, var_type, var_name, relationship, value) \ + CHECK_RESULT_VALUE_IMPL(entry, GetAs, var_type, var_name, relationship, value) + +#define CHECK_COUNTER_VALUE(entry, var_type, var_name, relationship, value) \ + CHECK_RESULT_VALUE_IMPL(entry, GetCounterAs, var_type, var_name, relationship, value) + +#define CHECK_FLOAT_RESULT_VALUE(entry, var_name, relationship, value, eps_factor) \ + CHECK_FLOAT_RESULT_VALUE_IMPL(entry, GetAs, double, var_name, relationship, value, eps_factor) + +#define CHECK_FLOAT_COUNTER_VALUE(entry, var_name, relationship, value, eps_factor) \ + CHECK_FLOAT_RESULT_VALUE_IMPL(entry, GetCounterAs, double, var_name, relationship, value, eps_factor) + +// clang-format on + +// ========================================================================= // +// --------------------------- Misc Utilities ------------------------------ // +// ========================================================================= // + +namespace { + +const char* const dec_re = "[0-9]*[.]?[0-9]+([eE][-+][0-9]+)?"; + +} // end namespace + +#endif // TEST_OUTPUT_TEST_H diff --git a/third_party/benchmark/test/output_test_helper.cc b/third_party/benchmark/test/output_test_helper.cc new file mode 100644 index 0000000..265f28a --- /dev/null +++ b/third_party/benchmark/test/output_test_helper.cc @@ -0,0 +1,520 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../src/benchmark_api_internal.h" +#include "../src/check.h" // NOTE: check.h is for internal use only! 
+#include "../src/log.h" // NOTE: log.h is for internal use only +#include "../src/re.h" // NOTE: re.h is for internal use only +#include "output_test.h" + +// ========================================================================= // +// ------------------------------ Internals -------------------------------- // +// ========================================================================= // +namespace internal { +namespace { + +using TestCaseList = std::vector; + +// Use a vector because the order elements are added matters during iteration. +// std::map/unordered_map don't guarantee that. +// For example: +// SetSubstitutions({{"%HelloWorld", "Hello"}, {"%Hello", "Hi"}}); +// Substitute("%HelloWorld") // Always expands to Hello. +using SubMap = std::vector>; + +TestCaseList& GetTestCaseList(TestCaseID ID) { + // Uses function-local statics to ensure initialization occurs + // before first use. + static TestCaseList lists[TC_NumID]; + return lists[ID]; +} + +SubMap& GetSubstitutions() { + // Don't use 'dec_re' from header because it may not yet be initialized. 
+ // clang-format off + static std::string safe_dec_re = "[0-9]*[.]?[0-9]+([eE][-+][0-9]+)?"; + static std::string time_re = "([0-9]+[.])?[0-9]+"; + static std::string percentage_re = "[0-9]+[.][0-9]{2}"; + static SubMap map = { + {"%float", "[0-9]*[.]?[0-9]+([eE][-+][0-9]+)?"}, + // human-readable float + {"%hrfloat", "[0-9]*[.]?[0-9]+([eE][-+][0-9]+)?[kKMGTPEZYmunpfazy]?i?"}, + {"%percentage", percentage_re}, + {"%int", "[ ]*[0-9]+"}, + {" %s ", "[ ]+"}, + {"%time", "[ ]*" + time_re + "[ ]+ns"}, + {"%console_report", "[ ]*" + time_re + "[ ]+ns [ ]*" + time_re + "[ ]+ns [ ]*[0-9]+"}, + {"%console_percentage_report", "[ ]*" + percentage_re + "[ ]+% [ ]*" + percentage_re + "[ ]+% [ ]*[0-9]+"}, + {"%console_us_report", "[ ]*" + time_re + "[ ]+us [ ]*" + time_re + "[ ]+us [ ]*[0-9]+"}, + {"%console_ms_report", "[ ]*" + time_re + "[ ]+ms [ ]*" + time_re + "[ ]+ms [ ]*[0-9]+"}, + {"%console_s_report", "[ ]*" + time_re + "[ ]+s [ ]*" + time_re + "[ ]+s [ ]*[0-9]+"}, + {"%console_time_only_report", "[ ]*" + time_re + "[ ]+ns [ ]*" + time_re + "[ ]+ns"}, + {"%console_us_report", "[ ]*" + time_re + "[ ]+us [ ]*" + time_re + "[ ]+us [ ]*[0-9]+"}, + {"%console_us_time_only_report", "[ ]*" + time_re + "[ ]+us [ ]*" + time_re + "[ ]+us"}, + {"%csv_header", + "name,iterations,real_time,cpu_time,time_unit,bytes_per_second," + "items_per_second,label,error_occurred,error_message"}, + {"%csv_report", "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",ns,,,,,"}, + {"%csv_us_report", "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",us,,,,,"}, + {"%csv_ms_report", "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",ms,,,,,"}, + {"%csv_s_report", "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",s,,,,,"}, + {"%csv_cv_report", "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",,,,,,"}, + {"%csv_bytes_report", + "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",ns," + safe_dec_re + ",,,,"}, + {"%csv_items_report", + "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",ns,," + safe_dec_re + ",,,"}, + 
{"%csv_bytes_items_report", + "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",ns," + safe_dec_re + + "," + safe_dec_re + ",,,"}, + {"%csv_label_report_begin", "[0-9]+," + safe_dec_re + "," + safe_dec_re + ",ns,,,"}, + {"%csv_label_report_end", ",,"}}; + // clang-format on + return map; +} + +std::string PerformSubstitutions(std::string source) { + SubMap const& subs = GetSubstitutions(); + using SizeT = std::string::size_type; + for (auto const& KV : subs) { + SizeT pos; + SizeT next_start = 0; + while ((pos = source.find(KV.first, next_start)) != std::string::npos) { + next_start = pos + KV.second.size(); + source.replace(pos, KV.first.size(), KV.second); + } + } + return source; +} + +void CheckCase(std::stringstream& remaining_output, TestCase const& TC, + TestCaseList const& not_checks) { + std::string first_line; + bool on_first = true; + std::string line; + while (remaining_output.eof() == false) { + BM_CHECK(remaining_output.good()); + std::getline(remaining_output, line); + if (on_first) { + first_line = line; + on_first = false; + } + for (const auto& NC : not_checks) { + BM_CHECK(!NC.regex->Match(line)) + << "Unexpected match for line \"" << line << "\" for MR_Not regex \"" + << NC.regex_str << "\"" + << "\n actual regex string \"" << TC.substituted_regex << "\"" + << "\n started matching near: " << first_line; + } + if (TC.regex->Match(line)) return; + BM_CHECK(TC.match_rule != MR_Next) + << "Expected line \"" << line << "\" to match regex \"" << TC.regex_str + << "\"" + << "\n actual regex string \"" << TC.substituted_regex << "\"" + << "\n started matching near: " << first_line; + } + BM_CHECK(remaining_output.eof() == false) + << "End of output reached before match for regex \"" << TC.regex_str + << "\" was found" + << "\n actual regex string \"" << TC.substituted_regex << "\"" + << "\n started matching near: " << first_line; +} + +void CheckCases(TestCaseList const& checks, std::stringstream& output) { + std::vector not_checks; + for (size_t i = 0; i 
< checks.size(); ++i) { + const auto& TC = checks[i]; + if (TC.match_rule == MR_Not) { + not_checks.push_back(TC); + continue; + } + CheckCase(output, TC, not_checks); + not_checks.clear(); + } +} + +class TestReporter : public benchmark::BenchmarkReporter { + public: + TestReporter(std::vector reps) + : reporters_(std::move(reps)) {} + + bool ReportContext(const Context& context) override { + bool last_ret = false; + bool first = true; + for (auto rep : reporters_) { + bool new_ret = rep->ReportContext(context); + BM_CHECK(first || new_ret == last_ret) + << "Reports return different values for ReportContext"; + first = false; + last_ret = new_ret; + } + (void)first; + return last_ret; + } + + void ReportRuns(const std::vector& report) override { + for (auto rep : reporters_) rep->ReportRuns(report); + } + void Finalize() override { + for (auto rep : reporters_) rep->Finalize(); + } + + private: + std::vector reporters_; +}; +} // namespace + +} // end namespace internal + +// ========================================================================= // +// -------------------------- Results checking ----------------------------- // +// ========================================================================= // + +namespace internal { + +// Utility class to manage subscribers for checking benchmark results. +// It works by parsing the CSV output to read the results. 
+class ResultsChecker { + public: + struct PatternAndFn : public TestCase { // reusing TestCase for its regexes + PatternAndFn(const std::string& rx, ResultsCheckFn fn_) + : TestCase(rx), fn(std::move(fn_)) {} + ResultsCheckFn fn; + }; + + std::vector check_patterns; + std::vector results; + std::vector field_names; + + void Add(const std::string& entry_pattern, const ResultsCheckFn& fn); + + void CheckResults(std::stringstream& output); + + private: + void SetHeader_(const std::string& csv_header); + void SetValues_(const std::string& entry_csv_line); + + std::vector SplitCsv_(const std::string& line); +}; + +// store the static ResultsChecker in a function to prevent initialization +// order problems +ResultsChecker& GetResultsChecker() { + static ResultsChecker rc; + return rc; +} + +// add a results checker for a benchmark +void ResultsChecker::Add(const std::string& entry_pattern, + const ResultsCheckFn& fn) { + check_patterns.emplace_back(entry_pattern, fn); +} + +// check the results of all subscribed benchmarks +void ResultsChecker::CheckResults(std::stringstream& output) { + // first reset the stream to the start + { + auto start = std::stringstream::pos_type(0); + // clear before calling tellg() + output.clear(); + // seek to zero only when needed + if (output.tellg() > start) output.seekg(start); + // and just in case + output.clear(); + } + // now go over every line and publish it to the ResultsChecker + std::string line; + bool on_first = true; + while (output.eof() == false) { + BM_CHECK(output.good()); + std::getline(output, line); + if (on_first) { + SetHeader_(line); // this is important + on_first = false; + continue; + } + SetValues_(line); + } + // finally we can call the subscribed check functions + for (const auto& p : check_patterns) { + BM_VLOG(2) << "--------------------------------\n"; + BM_VLOG(2) << "checking for benchmarks matching " << p.regex_str << "...\n"; + for (const auto& r : results) { + if (!p.regex->Match(r.name)) { + 
BM_VLOG(2) << p.regex_str << " is not matched by " << r.name << "\n"; + continue; + } + BM_VLOG(2) << p.regex_str << " is matched by " << r.name << "\n"; + BM_VLOG(1) << "Checking results of " << r.name << ": ... \n"; + p.fn(r); + BM_VLOG(1) << "Checking results of " << r.name << ": OK.\n"; + } + } +} + +// prepare for the names in this header +void ResultsChecker::SetHeader_(const std::string& csv_header) { + field_names = SplitCsv_(csv_header); +} + +// set the values for a benchmark +void ResultsChecker::SetValues_(const std::string& entry_csv_line) { + if (entry_csv_line.empty()) return; // some lines are empty + BM_CHECK(!field_names.empty()); + auto vals = SplitCsv_(entry_csv_line); + BM_CHECK_EQ(vals.size(), field_names.size()); + results.emplace_back(vals[0]); // vals[0] is the benchmark name + auto& entry = results.back(); + for (size_t i = 1, e = vals.size(); i < e; ++i) { + entry.values[field_names[i]] = vals[i]; + } +} + +// a quick'n'dirty csv splitter (eliminating quotes) +std::vector ResultsChecker::SplitCsv_(const std::string& line) { + std::vector out; + if (line.empty()) return out; + if (!field_names.empty()) out.reserve(field_names.size()); + size_t prev = 0, pos = line.find_first_of(','), curr = pos; + while (pos != line.npos) { + BM_CHECK(curr > 0); + if (line[prev] == '"') ++prev; + if (line[curr - 1] == '"') --curr; + out.push_back(line.substr(prev, curr - prev)); + prev = pos + 1; + pos = line.find_first_of(',', pos + 1); + curr = pos; + } + curr = line.size(); + if (line[prev] == '"') ++prev; + if (line[curr - 1] == '"') --curr; + out.push_back(line.substr(prev, curr - prev)); + return out; +} + +} // end namespace internal + +size_t AddChecker(const std::string& bm_name, const ResultsCheckFn& fn) { + auto& rc = internal::GetResultsChecker(); + rc.Add(bm_name, fn); + return rc.results.size(); +} + +int Results::NumThreads() const { + auto pos = name.find("/threads:"); + if (pos == name.npos) return 1; + auto end = name.find('/', pos + 9); 
+ std::stringstream ss; + ss << name.substr(pos + 9, end); + int num = 1; + ss >> num; + BM_CHECK(!ss.fail()); + return num; +} + +double Results::NumIterations() const { return GetAs("iterations"); } + +double Results::GetTime(BenchmarkTime which) const { + BM_CHECK(which == kCpuTime || which == kRealTime); + const char* which_str = which == kCpuTime ? "cpu_time" : "real_time"; + double val = GetAs(which_str); + auto unit = Get("time_unit"); + BM_CHECK(unit); + if (*unit == "ns") { + return val * 1.e-9; + } + if (*unit == "us") { + return val * 1.e-6; + } + if (*unit == "ms") { + return val * 1.e-3; + } + if (*unit == "s") { + return val; + } + BM_CHECK(1 == 0) << "unknown time unit: " << *unit; + return 0; +} + +// ========================================================================= // +// -------------------------- Public API Definitions------------------------ // +// ========================================================================= // + +TestCase::TestCase(std::string re, int rule) + : regex_str(std::move(re)), + match_rule(rule), + substituted_regex(internal::PerformSubstitutions(regex_str)), + regex(std::make_shared()) { + std::string err_str; + regex->Init(substituted_regex, &err_str); + BM_CHECK(err_str.empty()) + << "Could not construct regex \"" << substituted_regex << "\"" + << "\n originally \"" << regex_str << "\"" + << "\n got error: " << err_str; +} + +int AddCases(TestCaseID ID, std::initializer_list il) { + auto& L = internal::GetTestCaseList(ID); + L.insert(L.end(), il); + return 0; +} + +int SetSubstitutions( + std::initializer_list> il) { + auto& subs = internal::GetSubstitutions(); + for (auto KV : il) { + bool exists = false; + KV.second = internal::PerformSubstitutions(KV.second); + for (auto& EKV : subs) { + if (EKV.first == KV.first) { + EKV.second = std::move(KV.second); + exists = true; + break; + } + } + if (!exists) subs.push_back(std::move(KV)); + } + return 0; +} + +// Disable deprecated warnings temporarily because we 
need to reference +// CSVReporter but don't want to trigger -Werror=-Wdeprecated-declarations +BENCHMARK_DISABLE_DEPRECATED_WARNING + +void RunOutputTests(int argc, char* argv[]) { + using internal::GetTestCaseList; + benchmark::Initialize(&argc, argv); + auto options = benchmark::internal::GetOutputOptions(/*force_no_color*/ true); + benchmark::ConsoleReporter CR(options); + benchmark::JSONReporter JR; + benchmark::CSVReporter CSVR; + struct ReporterTest { + std::string name; + std::vector& output_cases; + std::vector& error_cases; + benchmark::BenchmarkReporter& reporter; + std::stringstream out_stream; + std::stringstream err_stream; + + ReporterTest(const std::string& n, std::vector& out_tc, + std::vector& err_tc, + benchmark::BenchmarkReporter& br) + : name(n), output_cases(out_tc), error_cases(err_tc), reporter(br) { + reporter.SetOutputStream(&out_stream); + reporter.SetErrorStream(&err_stream); + } + } TestCases[] = { + {std::string("ConsoleReporter"), GetTestCaseList(TC_ConsoleOut), + GetTestCaseList(TC_ConsoleErr), CR}, + {std::string("JSONReporter"), GetTestCaseList(TC_JSONOut), + GetTestCaseList(TC_JSONErr), JR}, + {std::string("CSVReporter"), GetTestCaseList(TC_CSVOut), + GetTestCaseList(TC_CSVErr), CSVR}, + }; + + // Create the test reporter and run the benchmarks. 
+ std::cout << "Running benchmarks...\n"; + internal::TestReporter test_rep({&CR, &JR, &CSVR}); + benchmark::RunSpecifiedBenchmarks(&test_rep); + + for (auto& rep_test : TestCases) { + std::string msg = + std::string("\nTesting ") + rep_test.name + std::string(" Output\n"); + std::string banner(msg.size() - 1, '-'); + std::cout << banner << msg << banner << "\n"; + + std::cerr << rep_test.err_stream.str(); + std::cout << rep_test.out_stream.str(); + + internal::CheckCases(rep_test.error_cases, rep_test.err_stream); + internal::CheckCases(rep_test.output_cases, rep_test.out_stream); + + std::cout << "\n"; + } + + // now that we know the output is as expected, we can dispatch + // the checks to subscribees. + auto& csv = TestCases[2]; + // would use == but gcc spits a warning + BM_CHECK(csv.name == std::string("CSVReporter")); + internal::GetResultsChecker().CheckResults(csv.out_stream); +} + +BENCHMARK_RESTORE_DEPRECATED_WARNING + +int SubstrCnt(const std::string& haystack, const std::string& pat) { + if (pat.length() == 0) return 0; + int count = 0; + for (size_t offset = haystack.find(pat); offset != std::string::npos; + offset = haystack.find(pat, offset + pat.length())) + ++count; + return count; +} + +static char ToHex(int ch) { + return ch < 10 ? static_cast('0' + ch) + : static_cast('a' + (ch - 10)); +} + +static char RandomHexChar() { + static std::mt19937 rd{std::random_device{}()}; + static std::uniform_int_distribution mrand{0, 15}; + return ToHex(mrand(rd)); +} + +static std::string GetRandomFileName() { + std::string model = "test.%%%%%%"; + for (auto& ch : model) { + if (ch == '%') ch = RandomHexChar(); + } + return model; +} + +static bool FileExists(std::string const& name) { + std::ifstream in(name.c_str()); + return in.good(); +} + +static std::string GetTempFileName() { + // This function attempts to avoid race conditions where two tests + // create the same file at the same time. However, it still introduces races + // similar to tmpnam. 
+ int retries = 3; + while (--retries) { + std::string name = GetRandomFileName(); + if (!FileExists(name)) return name; + } + std::cerr << "Failed to create unique temporary file name" << std::endl; + std::abort(); +} + +std::string GetFileReporterOutput(int argc, char* argv[]) { + std::vector new_argv(argv, argv + argc); + assert(static_cast(argc) == new_argv.size()); + + std::string tmp_file_name = GetTempFileName(); + std::cout << "Will be using this as the tmp file: " << tmp_file_name << '\n'; + + std::string tmp = "--benchmark_out="; + tmp += tmp_file_name; + new_argv.emplace_back(const_cast(tmp.c_str())); + + argc = int(new_argv.size()); + + benchmark::Initialize(&argc, new_argv.data()); + benchmark::RunSpecifiedBenchmarks(); + + // Read the output back from the file, and delete the file. + std::ifstream tmp_stream(tmp_file_name); + std::string output = std::string((std::istreambuf_iterator(tmp_stream)), + std::istreambuf_iterator()); + std::remove(tmp_file_name.c_str()); + + return output; +} diff --git a/third_party/benchmark/test/perf_counters_gtest.cc b/third_party/benchmark/test/perf_counters_gtest.cc new file mode 100644 index 0000000..2e63049 --- /dev/null +++ b/third_party/benchmark/test/perf_counters_gtest.cc @@ -0,0 +1,307 @@ +#include +#include + +#include "../src/perf_counters.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#ifndef GTEST_SKIP +struct MsgHandler { + void operator=(std::ostream&) {} +}; +#define GTEST_SKIP() return MsgHandler() = std::cout +#endif + +using benchmark::internal::PerfCounters; +using benchmark::internal::PerfCountersMeasurement; +using benchmark::internal::PerfCounterValues; +using ::testing::AllOf; +using ::testing::Gt; +using ::testing::Lt; + +namespace { +const char kGenericPerfEvent1[] = "CYCLES"; +const char kGenericPerfEvent2[] = "INSTRUCTIONS"; + +TEST(PerfCountersTest, Init) { + EXPECT_EQ(PerfCounters::Initialize(), PerfCounters::kSupported); +} + +TEST(PerfCountersTest, OneCounter) { + if 
(!PerfCounters::kSupported) { + GTEST_SKIP() << "Performance counters not supported.\n"; + } + EXPECT_TRUE(PerfCounters::Initialize()); + EXPECT_EQ(PerfCounters::Create({kGenericPerfEvent1}).num_counters(), 1); +} + +TEST(PerfCountersTest, NegativeTest) { + if (!PerfCounters::kSupported) { + EXPECT_FALSE(PerfCounters::Initialize()); + return; + } + EXPECT_TRUE(PerfCounters::Initialize()); + // Safety checks + // Create() will always create a valid object, even if passed no or + // wrong arguments as the new behavior is to warn and drop unsupported + // counters + EXPECT_EQ(PerfCounters::Create({}).num_counters(), 0); + EXPECT_EQ(PerfCounters::Create({""}).num_counters(), 0); + EXPECT_EQ(PerfCounters::Create({"not a counter name"}).num_counters(), 0); + { + // Try sneaking in a bad egg to see if it is filtered out. The + // number of counters has to be two, not zero + auto counter = + PerfCounters::Create({kGenericPerfEvent2, "", kGenericPerfEvent1}); + EXPECT_EQ(counter.num_counters(), 2); + EXPECT_EQ(counter.names(), std::vector( + {kGenericPerfEvent2, kGenericPerfEvent1})); + } + { + // Try sneaking in an outrageous counter, like a fat finger mistake + auto counter = PerfCounters::Create( + {kGenericPerfEvent2, "not a counter name", kGenericPerfEvent1}); + EXPECT_EQ(counter.num_counters(), 2); + EXPECT_EQ(counter.names(), std::vector( + {kGenericPerfEvent2, kGenericPerfEvent1})); + } + { + // Finally try a golden input - it should like both of them + EXPECT_EQ(PerfCounters::Create({kGenericPerfEvent1, kGenericPerfEvent2}) + .num_counters(), + 2); + } + { + // Add a bad apple in the end of the chain to check the edges + auto counter = PerfCounters::Create( + {kGenericPerfEvent1, kGenericPerfEvent2, "bad event name"}); + EXPECT_EQ(counter.num_counters(), 2); + EXPECT_EQ(counter.names(), std::vector( + {kGenericPerfEvent1, kGenericPerfEvent2})); + } +} + +TEST(PerfCountersTest, Read1Counter) { + if (!PerfCounters::kSupported) { + GTEST_SKIP() << "Test skipped 
because libpfm is not supported.\n"; + } + EXPECT_TRUE(PerfCounters::Initialize()); + auto counters = PerfCounters::Create({kGenericPerfEvent1}); + EXPECT_EQ(counters.num_counters(), 1); + PerfCounterValues values1(1); + EXPECT_TRUE(counters.Snapshot(&values1)); + EXPECT_GT(values1[0], 0); + PerfCounterValues values2(1); + EXPECT_TRUE(counters.Snapshot(&values2)); + EXPECT_GT(values2[0], 0); + EXPECT_GT(values2[0], values1[0]); +} + +TEST(PerfCountersTest, Read2Counters) { + if (!PerfCounters::kSupported) { + GTEST_SKIP() << "Test skipped because libpfm is not supported.\n"; + } + EXPECT_TRUE(PerfCounters::Initialize()); + auto counters = + PerfCounters::Create({kGenericPerfEvent1, kGenericPerfEvent2}); + EXPECT_EQ(counters.num_counters(), 2); + PerfCounterValues values1(2); + EXPECT_TRUE(counters.Snapshot(&values1)); + EXPECT_GT(values1[0], 0); + EXPECT_GT(values1[1], 0); + PerfCounterValues values2(2); + EXPECT_TRUE(counters.Snapshot(&values2)); + EXPECT_GT(values2[0], 0); + EXPECT_GT(values2[1], 0); +} + +TEST(PerfCountersTest, ReopenExistingCounters) { + // This test works in recent and old Intel hardware, Pixel 3, and Pixel 6. + // However we cannot make assumptions beyond 2 HW counters due to Pixel 6. + if (!PerfCounters::kSupported) { + GTEST_SKIP() << "Test skipped because libpfm is not supported.\n"; + } + EXPECT_TRUE(PerfCounters::Initialize()); + std::vector kMetrics({kGenericPerfEvent1}); + std::vector counters(2); + for (auto& counter : counters) { + counter = PerfCounters::Create(kMetrics); + } + PerfCounterValues values(1); + EXPECT_TRUE(counters[0].Snapshot(&values)); + EXPECT_TRUE(counters[1].Snapshot(&values)); +} + +TEST(PerfCountersTest, CreateExistingMeasurements) { + // The test works (i.e. causes read to fail) for the assumptions + // about hardware capabilities (i.e. small number (2) hardware + // counters) at this date, + // the same as previous test ReopenExistingCounters. 
+ if (!PerfCounters::kSupported) { + GTEST_SKIP() << "Test skipped because libpfm is not supported.\n"; + } + EXPECT_TRUE(PerfCounters::Initialize()); + + // This means we will try 10 counters but we can only guarantee + // for sure at this time that only 3 will work. Perhaps in the future + // we could use libpfm to query for the hardware limits on this + // particular platform. + const int kMaxCounters = 10; + const int kMinValidCounters = 2; + + // Let's use a ubiquitous counter that is guaranteed to work + // on all platforms + const std::vector kMetrics{"cycles"}; + + // Cannot create a vector of actual objects because the + // copy constructor of PerfCounters is deleted - and so is + // implicitly deleted on PerfCountersMeasurement too + std::vector> + perf_counter_measurements; + + perf_counter_measurements.reserve(kMaxCounters); + for (int j = 0; j < kMaxCounters; ++j) { + perf_counter_measurements.emplace_back( + new PerfCountersMeasurement(kMetrics)); + } + + std::vector> measurements; + + // Start all counters together to see if they hold + size_t max_counters = kMaxCounters; + for (size_t i = 0; i < kMaxCounters; ++i) { + auto& counter(*perf_counter_measurements[i]); + EXPECT_EQ(counter.num_counters(), 1); + if (!counter.Start()) { + max_counters = i; + break; + }; + } + + ASSERT_GE(max_counters, kMinValidCounters); + + // Start all together + for (size_t i = 0; i < max_counters; ++i) { + auto& counter(*perf_counter_measurements[i]); + EXPECT_TRUE(counter.Stop(measurements) || (i >= kMinValidCounters)); + } + + // Start/stop individually + for (size_t i = 0; i < max_counters; ++i) { + auto& counter(*perf_counter_measurements[i]); + measurements.clear(); + counter.Start(); + EXPECT_TRUE(counter.Stop(measurements) || (i >= kMinValidCounters)); + } +} + +// We try to do some meaningful work here but the compiler +// insists in optimizing away our loop so we had to add a +// no-optimize macro. 
In case it fails, we added some entropy +// to this pool as well. + +BENCHMARK_DONT_OPTIMIZE size_t do_work() { + static std::mt19937 rd{std::random_device{}()}; + static std::uniform_int_distribution mrand(0, 10); + const size_t kNumLoops = 1000000; + size_t sum = 0; + for (size_t j = 0; j < kNumLoops; ++j) { + sum += mrand(rd); + } + benchmark::DoNotOptimize(sum); + return sum; +} + +void measure(size_t threadcount, PerfCounterValues* before, + PerfCounterValues* after) { + BM_CHECK_NE(before, nullptr); + BM_CHECK_NE(after, nullptr); + std::vector threads(threadcount); + auto work = [&]() { BM_CHECK(do_work() > 1000); }; + + // We need to first set up the counters, then start the threads, so the + // threads would inherit the counters. But later, we need to first destroy + // the thread pool (so all the work finishes), then measure the counters. So + // the scopes overlap, and we need to explicitly control the scope of the + // threadpool. + auto counters = + PerfCounters::Create({kGenericPerfEvent1, kGenericPerfEvent2}); + for (auto& t : threads) t = std::thread(work); + counters.Snapshot(before); + for (auto& t : threads) t.join(); + counters.Snapshot(after); +} + +TEST(PerfCountersTest, MultiThreaded) { + if (!PerfCounters::kSupported) { + GTEST_SKIP() << "Test skipped because libpfm is not supported."; + } + EXPECT_TRUE(PerfCounters::Initialize()); + PerfCounterValues before(2); + PerfCounterValues after(2); + + // Notice that this test will work even if we taskset it to a single CPU + // In this case the threads will run sequentially + // Start two threads and measure the number of combined cycles and + // instructions + measure(2, &before, &after); + std::vector Elapsed2Threads{ + static_cast(after[0] - before[0]), + static_cast(after[1] - before[1])}; + + // Start four threads and measure the number of combined cycles and + // instructions + measure(4, &before, &after); + std::vector Elapsed4Threads{ + static_cast(after[0] - before[0]), + 
static_cast(after[1] - before[1])}; + + // The following expectations fail (at least on a beefy workstation with lots + // of cpus) - it seems that in some circumstances the runtime of 4 threads + // can even be better than with 2. + // So instead of expecting 4 threads to be slower, let's just make sure they + // do not differ too much in general (one is not more than 10x than the + // other). + EXPECT_THAT(Elapsed4Threads[0] / Elapsed2Threads[0], AllOf(Gt(0.1), Lt(10))); + EXPECT_THAT(Elapsed4Threads[1] / Elapsed2Threads[1], AllOf(Gt(0.1), Lt(10))); +} + +TEST(PerfCountersTest, HardwareLimits) { + // The test works (i.e. causes read to fail) for the assumptions + // about hardware capabilities (i.e. small number (3-4) hardware + // counters) at this date, + // the same as previous test ReopenExistingCounters. + if (!PerfCounters::kSupported) { + GTEST_SKIP() << "Test skipped because libpfm is not supported.\n"; + } + EXPECT_TRUE(PerfCounters::Initialize()); + + // Taken from `perf list`, but focusses only on those HW events that actually + // were reported when running `sudo perf stat -a sleep 10`, intersected over + // several platforms. All HW events listed in the first command not reported + // in the second seem to not work. This is sad as we don't really get to test + // the grouping here (groups can contain up to 6 members)... 
+ std::vector counter_names{ + "cycles", // leader + "instructions", // + "branch-misses", // + }; + + // In the off-chance that some of these values are not supported, + // we filter them out so the test will complete without failure + // albeit it might not actually test the grouping on that platform + std::vector valid_names; + for (const std::string& name : counter_names) { + if (PerfCounters::IsCounterSupported(name)) { + valid_names.push_back(name); + } + } + PerfCountersMeasurement counter(valid_names); + + std::vector> measurements; + + counter.Start(); + EXPECT_TRUE(counter.Stop(measurements)); +} + +} // namespace diff --git a/third_party/benchmark/test/perf_counters_test.cc b/third_party/benchmark/test/perf_counters_test.cc new file mode 100644 index 0000000..3cc593e --- /dev/null +++ b/third_party/benchmark/test/perf_counters_test.cc @@ -0,0 +1,92 @@ +#include +#undef NDEBUG + +#include "../src/commandlineflags.h" +#include "../src/perf_counters.h" +#include "benchmark/benchmark.h" +#include "output_test.h" + +namespace benchmark { + +BM_DECLARE_string(benchmark_perf_counters); + +} // namespace benchmark + +static void BM_Simple(benchmark::State& state) { + for (auto _ : state) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } +} +BENCHMARK(BM_Simple); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Simple\",$"}}); + +const int kIters = 1000000; + +void BM_WithoutPauseResume(benchmark::State& state) { + int n = 0; + + for (auto _ : state) { + for (auto i = 0; i < kIters; ++i) { + n = 1 - n; + benchmark::DoNotOptimize(n); + } + } +} + +BENCHMARK(BM_WithoutPauseResume); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_WithoutPauseResume\",$"}}); + +void BM_WithPauseResume(benchmark::State& state) { + int m = 0, n = 0; + + for (auto _ : state) { + for (auto i = 0; i < kIters; ++i) { + n = 1 - n; + benchmark::DoNotOptimize(n); + } + + state.PauseTiming(); + for (auto j = 0; j < kIters; ++j) { + m = 
1 - m; + benchmark::DoNotOptimize(m); + } + state.ResumeTiming(); + } +} + +BENCHMARK(BM_WithPauseResume); + +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_WithPauseResume\",$"}}); + +static void CheckSimple(Results const& e) { + CHECK_COUNTER_VALUE(e, double, "CYCLES", GT, 0); +} + +double withoutPauseResumeInstrCount = 0.0; +double withPauseResumeInstrCount = 0.0; + +static void SaveInstrCountWithoutResume(Results const& e) { + withoutPauseResumeInstrCount = e.GetAs("INSTRUCTIONS"); +} + +static void SaveInstrCountWithResume(Results const& e) { + withPauseResumeInstrCount = e.GetAs("INSTRUCTIONS"); +} + +CHECK_BENCHMARK_RESULTS("BM_Simple", &CheckSimple); +CHECK_BENCHMARK_RESULTS("BM_WithoutPauseResume", &SaveInstrCountWithoutResume); +CHECK_BENCHMARK_RESULTS("BM_WithPauseResume", &SaveInstrCountWithResume); + +int main(int argc, char* argv[]) { + if (!benchmark::internal::PerfCounters::kSupported) { + return 0; + } + benchmark::FLAGS_benchmark_perf_counters = "CYCLES,INSTRUCTIONS"; + benchmark::internal::PerfCounters::Initialize(); + RunOutputTests(argc, argv); + + BM_CHECK_GT(withPauseResumeInstrCount, kIters); + BM_CHECK_GT(withoutPauseResumeInstrCount, kIters); + BM_CHECK_LT(withPauseResumeInstrCount, 1.5 * withoutPauseResumeInstrCount); +} diff --git a/third_party/benchmark/test/profiler_manager_gtest.cc b/third_party/benchmark/test/profiler_manager_gtest.cc new file mode 100644 index 0000000..434e4ec --- /dev/null +++ b/third_party/benchmark/test/profiler_manager_gtest.cc @@ -0,0 +1,42 @@ +#include + +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" + +namespace { + +class TestProfilerManager : public benchmark::ProfilerManager { + public: + void AfterSetupStart() override { ++start_called; } + void BeforeTeardownStop() override { ++stop_called; } + + int start_called = 0; + int stop_called = 0; +}; + +void BM_empty(benchmark::State& state) { + for (auto _ : state) { + auto iterations = state.iterations(); + benchmark::DoNotOptimize(iterations); + } +} 
+BENCHMARK(BM_empty); + +TEST(ProfilerManager, ReregisterManager) { +#if GTEST_HAS_DEATH_TEST + // Tests only runnable in debug mode (when BM_CHECK is enabled). +#ifndef NDEBUG +#ifndef TEST_BENCHMARK_LIBRARY_HAS_NO_ASSERTIONS + ASSERT_DEATH_IF_SUPPORTED( + { + std::unique_ptr pm(new TestProfilerManager()); + benchmark::RegisterProfilerManager(pm.get()); + benchmark::RegisterProfilerManager(pm.get()); + }, + "RegisterProfilerManager"); +#endif +#endif +#endif +} + +} // namespace diff --git a/third_party/benchmark/test/profiler_manager_test.cc b/third_party/benchmark/test/profiler_manager_test.cc new file mode 100644 index 0000000..3b08a60 --- /dev/null +++ b/third_party/benchmark/test/profiler_manager_test.cc @@ -0,0 +1,50 @@ +// FIXME: WIP + +#include + +#include "benchmark/benchmark.h" +#include "output_test.h" + +class TestProfilerManager : public benchmark::ProfilerManager { + public: + void AfterSetupStart() override { ++start_called; } + void BeforeTeardownStop() override { ++stop_called; } + + int start_called = 0; + int stop_called = 0; +}; + +void BM_empty(benchmark::State& state) { + for (auto _ : state) { + auto iterations = state.iterations(); + benchmark::DoNotOptimize(iterations); + } +} +BENCHMARK(BM_empty); + +ADD_CASES(TC_ConsoleOut, {{"^BM_empty %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_empty\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_empty\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_empty\",%csv_report$"}}); + +int main(int argc, char* argv[]) { + std::unique_ptr pm(new TestProfilerManager()); + + 
benchmark::RegisterProfilerManager(pm.get()); + RunOutputTests(argc, argv); + benchmark::RegisterProfilerManager(nullptr); + + assert(pm->start_called == 1); + assert(pm->stop_called == 1); +} diff --git a/third_party/benchmark/test/register_benchmark_test.cc b/third_party/benchmark/test/register_benchmark_test.cc new file mode 100644 index 0000000..d69d144 --- /dev/null +++ b/third_party/benchmark/test/register_benchmark_test.cc @@ -0,0 +1,196 @@ + +#undef NDEBUG +#include +#include + +#include "../src/check.h" // NOTE: check.h is for internal use only! +#include "benchmark/benchmark.h" + +namespace { + +class TestReporter : public benchmark::ConsoleReporter { + public: + void ReportRuns(const std::vector& report) override { + all_runs_.insert(all_runs_.end(), begin(report), end(report)); + ConsoleReporter::ReportRuns(report); + } + + std::vector all_runs_; +}; + +struct TestCase { + const std::string name; + const std::string label; + // Note: not explicit as we rely on it being converted through ADD_CASES. + TestCase(const std::string& xname) : TestCase(xname, "") {} + TestCase(const std::string& xname, const std::string& xlabel) + : name(xname), label(xlabel) {} + + typedef benchmark::BenchmarkReporter::Run Run; + + void CheckRun(Run const& run) const { + // clang-format off + BM_CHECK(name == run.benchmark_name()) << "expected " << name << " got " + << run.benchmark_name(); + if (!label.empty()) { + BM_CHECK(run.report_label == label) << "expected " << label << " got " + << run.report_label; + } else { + BM_CHECK(run.report_label.empty()); + } + // clang-format on + } +}; + +std::vector ExpectedResults; + +int AddCases(std::initializer_list const& v) { + for (const auto& N : v) { + ExpectedResults.push_back(N); + } + return 0; +} + +#define CONCAT(x, y) CONCAT2(x, y) +#define CONCAT2(x, y) x##y +#define ADD_CASES(...) 
int CONCAT(dummy, __LINE__) = AddCases({__VA_ARGS__}) + +} // end namespace + +typedef benchmark::internal::Benchmark* ReturnVal; + +//----------------------------------------------------------------------------// +// Test RegisterBenchmark with no additional arguments +//----------------------------------------------------------------------------// +void BM_function(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_function); +ReturnVal dummy = benchmark::RegisterBenchmark( + "BM_function_manual_registration", BM_function); +ADD_CASES({"BM_function"}, {"BM_function_manual_registration"}); + +//----------------------------------------------------------------------------// +// Test RegisterBenchmark with additional arguments +// Note: GCC <= 4.8 do not support this form of RegisterBenchmark because they +// reject the variadic pack expansion of lambda captures. +//----------------------------------------------------------------------------// +#ifndef BENCHMARK_HAS_NO_VARIADIC_REGISTER_BENCHMARK + +void BM_extra_args(benchmark::State& st, const char* label) { + for (auto _ : st) { + } + st.SetLabel(label); +} +int RegisterFromFunction() { + std::pair cases[] = { + {"test1", "One"}, {"test2", "Two"}, {"test3", "Three"}}; + for (auto const& c : cases) + benchmark::RegisterBenchmark(c.first, &BM_extra_args, c.second); + return 0; +} +int dummy2 = RegisterFromFunction(); +ADD_CASES({"test1", "One"}, {"test2", "Two"}, {"test3", "Three"}); + +#endif // BENCHMARK_HAS_NO_VARIADIC_REGISTER_BENCHMARK + +//----------------------------------------------------------------------------// +// Test RegisterBenchmark with DISABLED_ benchmark +//----------------------------------------------------------------------------// +void DISABLED_BM_function(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(DISABLED_BM_function); +ReturnVal dummy3 = benchmark::RegisterBenchmark("DISABLED_BM_function_manual", + DISABLED_BM_function); +// No need to add 
cases because we don't expect them to run. + +//----------------------------------------------------------------------------// +// Test RegisterBenchmark with different callable types +//----------------------------------------------------------------------------// + +struct CustomFixture { + void operator()(benchmark::State& st) { + for (auto _ : st) { + } + } +}; + +void TestRegistrationAtRuntime() { +#ifdef BENCHMARK_HAS_CXX11 + { + CustomFixture fx; + benchmark::RegisterBenchmark("custom_fixture", fx); + AddCases({std::string("custom_fixture")}); + } +#endif +#ifndef BENCHMARK_HAS_NO_VARIADIC_REGISTER_BENCHMARK + { + const char* x = "42"; + auto capturing_lam = [=](benchmark::State& st) { + for (auto _ : st) { + } + st.SetLabel(x); + }; + benchmark::RegisterBenchmark("lambda_benchmark", capturing_lam); + AddCases({{"lambda_benchmark", x}}); + } +#endif +} + +// Test that all benchmarks, registered at either during static init or runtime, +// are run and the results are passed to the reported. +void RunTestOne() { + TestRegistrationAtRuntime(); + + TestReporter test_reporter; + benchmark::RunSpecifiedBenchmarks(&test_reporter); + + typedef benchmark::BenchmarkReporter::Run Run; + auto EB = ExpectedResults.begin(); + + for (Run const& run : test_reporter.all_runs_) { + assert(EB != ExpectedResults.end()); + EB->CheckRun(run); + ++EB; + } + assert(EB == ExpectedResults.end()); +} + +// Test that ClearRegisteredBenchmarks() clears all previously registered +// benchmarks. +// Also test that new benchmarks can be registered and ran afterwards. 
+void RunTestTwo() { + assert(ExpectedResults.size() != 0 && + "must have at least one registered benchmark"); + ExpectedResults.clear(); + benchmark::ClearRegisteredBenchmarks(); + + TestReporter test_reporter; + size_t num_ran = benchmark::RunSpecifiedBenchmarks(&test_reporter); + assert(num_ran == 0); + assert(test_reporter.all_runs_.begin() == test_reporter.all_runs_.end()); + + TestRegistrationAtRuntime(); + num_ran = benchmark::RunSpecifiedBenchmarks(&test_reporter); + assert(num_ran == ExpectedResults.size()); + + typedef benchmark::BenchmarkReporter::Run Run; + auto EB = ExpectedResults.begin(); + + for (Run const& run : test_reporter.all_runs_) { + assert(EB != ExpectedResults.end()); + EB->CheckRun(run); + ++EB; + } + assert(EB == ExpectedResults.end()); +} + +int main(int argc, char* argv[]) { + benchmark::Initialize(&argc, argv); + + RunTestOne(); + RunTestTwo(); +} diff --git a/third_party/benchmark/test/repetitions_test.cc b/third_party/benchmark/test/repetitions_test.cc new file mode 100644 index 0000000..569777d --- /dev/null +++ b/third_party/benchmark/test/repetitions_test.cc @@ -0,0 +1,214 @@ + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// ========================================================================= // +// ------------------------ Testing Basic Output --------------------------- // +// ========================================================================= // + +static void BM_ExplicitRepetitions(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_ExplicitRepetitions)->Repetitions(2); + +ADD_CASES(TC_ConsoleOut, + {{"^BM_ExplicitRepetitions/repeats:2 %console_report$"}}); +ADD_CASES(TC_ConsoleOut, + {{"^BM_ExplicitRepetitions/repeats:2 %console_report$"}}); +ADD_CASES(TC_ConsoleOut, + {{"^BM_ExplicitRepetitions/repeats:2_mean %console_report$"}}); +ADD_CASES(TC_ConsoleOut, + {{"^BM_ExplicitRepetitions/repeats:2_median %console_report$"}}); +ADD_CASES(TC_ConsoleOut, + 
{{"^BM_ExplicitRepetitions/repeats:2_stddev %console_report$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_ExplicitRepetitions/repeats:2\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ExplicitRepetitions/repeats:2\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_ExplicitRepetitions/repeats:2\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ExplicitRepetitions/repeats:2\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_ExplicitRepetitions/repeats:2_mean\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ExplicitRepetitions/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_ExplicitRepetitions/repeats:2_median\",$"}, + {"\"family_index\": 0,$", MR_Next}, + 
{"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ExplicitRepetitions/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_ExplicitRepetitions/repeats:2_stddev\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ExplicitRepetitions/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_ExplicitRepetitions/repeats:2\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_ExplicitRepetitions/repeats:2\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_ExplicitRepetitions/repeats:2_mean\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_ExplicitRepetitions/repeats:2_median\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_ExplicitRepetitions/repeats:2_stddev\",%csv_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Basic Output --------------------------- // +// ========================================================================= // + +static void BM_ImplicitRepetitions(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_ImplicitRepetitions); + +ADD_CASES(TC_ConsoleOut, 
{{"^BM_ImplicitRepetitions %console_report$"}}); +ADD_CASES(TC_ConsoleOut, {{"^BM_ImplicitRepetitions %console_report$"}}); +ADD_CASES(TC_ConsoleOut, {{"^BM_ImplicitRepetitions %console_report$"}}); +ADD_CASES(TC_ConsoleOut, {{"^BM_ImplicitRepetitions_mean %console_report$"}}); +ADD_CASES(TC_ConsoleOut, {{"^BM_ImplicitRepetitions_median %console_report$"}}); +ADD_CASES(TC_ConsoleOut, {{"^BM_ImplicitRepetitions_stddev %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_ImplicitRepetitions\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ImplicitRepetitions\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_ImplicitRepetitions\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ImplicitRepetitions\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_ImplicitRepetitions\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ImplicitRepetitions\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + 
{"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_ImplicitRepetitions_mean\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ImplicitRepetitions\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_ImplicitRepetitions_median\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ImplicitRepetitions\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_ImplicitRepetitions_stddev\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_ImplicitRepetitions\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_ImplicitRepetitions\",%csv_report$"}}); 
+ADD_CASES(TC_CSVOut, {{"^\"BM_ImplicitRepetitions\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_ImplicitRepetitions_mean\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_ImplicitRepetitions_median\",%csv_report$"}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_ImplicitRepetitions_stddev\",%csv_report$"}}); + +// ========================================================================= // +// --------------------------- TEST CASES END ------------------------------ // +// ========================================================================= // + +int main(int argc, char* argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/test/report_aggregates_only_test.cc b/third_party/benchmark/test/report_aggregates_only_test.cc new file mode 100644 index 0000000..47da503 --- /dev/null +++ b/third_party/benchmark/test/report_aggregates_only_test.cc @@ -0,0 +1,41 @@ + +#undef NDEBUG +#include +#include + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// Ok this test is super ugly. We want to check what happens with the file +// reporter in the presence of ReportAggregatesOnly(). +// We do not care about console output, the normal tests check that already. + +void BM_SummaryRepeat(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_SummaryRepeat)->Repetitions(3)->ReportAggregatesOnly(); + +int main(int argc, char* argv[]) { + const std::string output = GetFileReporterOutput(argc, argv); + + if (SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3") != 4 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_mean\"") != 1 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_median\"") != + 1 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_stddev\"") != + 1 || + SubstrCnt(output, "\"name\": \"BM_SummaryRepeat/repeats:3_cv\"") != 1) { + std::cout << "Precondition mismatch. 
Expected to only find four " + "occurrences of \"BM_SummaryRepeat/repeats:3\" substring:\n" + "\"name\": \"BM_SummaryRepeat/repeats:3_mean\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_median\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_stddev\", " + "\"name\": \"BM_SummaryRepeat/repeats:3_cv\"\nThe entire " + "output:\n"; + std::cout << output; + return 1; + } + + return 0; +} diff --git a/third_party/benchmark/test/reporter_output_test.cc b/third_party/benchmark/test/reporter_output_test.cc new file mode 100644 index 0000000..7867165 --- /dev/null +++ b/third_party/benchmark/test/reporter_output_test.cc @@ -0,0 +1,1133 @@ + +#undef NDEBUG +#include +#include + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// ========================================================================= // +// ---------------------- Testing Prologue Output -------------------------- // +// ========================================================================= // + +ADD_CASES(TC_ConsoleOut, {{"^[-]+$", MR_Next}, + {"^Benchmark %s Time %s CPU %s Iterations$", MR_Next}, + {"^[-]+$", MR_Next}}); +static int AddContextCases() { + AddCases(TC_ConsoleErr, + { + {"^%int-%int-%intT%int:%int:%int[-+]%int:%int$", MR_Default}, + {"Running .*(/|\\\\)reporter_output_test(\\.exe)?$", MR_Next}, + {"Run on \\(%int X %float MHz CPU s?\\)", MR_Next}, + }); + AddCases(TC_JSONOut, + {{"^\\{", MR_Default}, + {"\"context\":", MR_Next}, + {"\"date\": \"", MR_Next}, + {"\"host_name\":", MR_Next}, + {"\"executable\": \".*(/|\\\\)reporter_output_test(\\.exe)?\",", + MR_Next}, + {"\"num_cpus\": %int,$", MR_Next}, + {"\"mhz_per_cpu\": %float,$", MR_Next}, + {"\"caches\": \\[$", MR_Default}}); + auto const& Info = benchmark::CPUInfo::Get(); + auto const& Caches = Info.caches; + if (!Caches.empty()) { + AddCases(TC_ConsoleErr, {{"CPU Caches:$", MR_Next}}); + } + for (size_t I = 0; I < Caches.size(); ++I) { + std::string num_caches_str = + Caches[I].num_sharing != 0 ? 
" \\(x%int\\)$" : "$"; + AddCases(TC_ConsoleErr, + {{"L%int (Data|Instruction|Unified) %int KiB" + num_caches_str, + MR_Next}}); + AddCases(TC_JSONOut, {{"\\{$", MR_Next}, + {"\"type\": \"", MR_Next}, + {"\"level\": %int,$", MR_Next}, + {"\"size\": %int,$", MR_Next}, + {"\"num_sharing\": %int$", MR_Next}, + {"}[,]{0,1}$", MR_Next}}); + } + AddCases(TC_JSONOut, {{"],$"}}); + auto const& LoadAvg = Info.load_avg; + if (!LoadAvg.empty()) { + AddCases(TC_ConsoleErr, + {{"Load Average: (%float, ){0,2}%float$", MR_Next}}); + } + AddCases(TC_JSONOut, {{"\"load_avg\": \\[(%float,?){0,3}],$", MR_Next}}); + AddCases(TC_JSONOut, {{"\"library_version\": \".*\",$", MR_Next}}); + AddCases(TC_JSONOut, {{"\"library_build_type\": \".*\",$", MR_Next}}); + AddCases(TC_JSONOut, {{"\"json_schema_version\": 1$", MR_Next}}); + return 0; +} +int dummy_register = AddContextCases(); +ADD_CASES(TC_CSVOut, {{"%csv_header"}}); + +// ========================================================================= // +// ------------------------ Testing Basic Output --------------------------- // +// ========================================================================= // + +void BM_basic(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_basic); + +ADD_CASES(TC_ConsoleOut, {{"^BM_basic %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_basic\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_basic\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_basic\",%csv_report$"}}); + +// ========================================================================= // +// 
------------------------ Testing Bytes per Second Output ---------------- // +// ========================================================================= // + +void BM_bytes_per_second(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + state.SetBytesProcessed(1); +} +BENCHMARK(BM_bytes_per_second); + +ADD_CASES(TC_ConsoleOut, {{"^BM_bytes_per_second %console_report " + "bytes_per_second=%float[kM]{0,1}/s$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_bytes_per_second\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_bytes_per_second\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bytes_per_second\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_bytes_per_second\",%csv_bytes_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Items per Second Output ---------------- // +// ========================================================================= // + +void BM_items_per_second(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + state.SetItemsProcessed(1); +} +BENCHMARK(BM_items_per_second); + +ADD_CASES(TC_ConsoleOut, {{"^BM_items_per_second %console_report " + "items_per_second=%float[kM]{0,1}/s$"}}); +ADD_CASES(TC_JSONOut, 
{{"\"name\": \"BM_items_per_second\",$"}, + {"\"family_index\": 2,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_items_per_second\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"items_per_second\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_items_per_second\",%csv_items_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Label Output --------------------------- // +// ========================================================================= // + +void BM_label(benchmark::State& state) { + for (auto _ : state) { + } + state.SetLabel("some label"); +} +BENCHMARK(BM_label); + +ADD_CASES(TC_ConsoleOut, {{"^BM_label %console_report some label$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_label\",$"}, + {"\"family_index\": 3,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_label\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"label\": \"some label\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_label\",%csv_label_report_begin\"some " + "label\"%csv_label_report_end$"}}); + +// ========================================================================= // +// ------------------------ Testing Time Label Output ---------------------- // +// 
========================================================================= // + +void BM_time_label_nanosecond(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_time_label_nanosecond)->Unit(benchmark::kNanosecond); + +ADD_CASES(TC_ConsoleOut, {{"^BM_time_label_nanosecond %console_report$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_time_label_nanosecond\",$"}, + {"\"family_index\": 4,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_time_label_nanosecond\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_time_label_nanosecond\",%csv_report$"}}); + +void BM_time_label_microsecond(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_time_label_microsecond)->Unit(benchmark::kMicrosecond); + +ADD_CASES(TC_ConsoleOut, {{"^BM_time_label_microsecond %console_us_report$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_time_label_microsecond\",$"}, + {"\"family_index\": 5,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_time_label_microsecond\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"us\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_time_label_microsecond\",%csv_us_report$"}}); + +void BM_time_label_millisecond(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_time_label_millisecond)->Unit(benchmark::kMillisecond); + 
+ADD_CASES(TC_ConsoleOut, {{"^BM_time_label_millisecond %console_ms_report$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_time_label_millisecond\",$"}, + {"\"family_index\": 6,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_time_label_millisecond\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ms\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_time_label_millisecond\",%csv_ms_report$"}}); + +void BM_time_label_second(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_time_label_second)->Unit(benchmark::kSecond); + +ADD_CASES(TC_ConsoleOut, {{"^BM_time_label_second %console_s_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_time_label_second\",$"}, + {"\"family_index\": 7,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_time_label_second\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"s\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_time_label_second\",%csv_s_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Error Output --------------------------- // +// ========================================================================= // + +void BM_error(benchmark::State& state) { + state.SkipWithError("message"); + for (auto _ : state) { + } +} +BENCHMARK(BM_error); +ADD_CASES(TC_ConsoleOut, {{"^BM_error[ ]+ERROR OCCURRED: 
'message'$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_error\",$"}, + {"\"family_index\": 8,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_error\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"error_occurred\": true,$", MR_Next}, + {"\"error_message\": \"message\",$", MR_Next}}); + +ADD_CASES(TC_CSVOut, {{"^\"BM_error\",,,,,,,,true,\"message\"$"}}); + +// ========================================================================= // +// ------------------------ Testing No Arg Name Output ----------------------- +// // +// ========================================================================= // + +void BM_no_arg_name(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_no_arg_name)->Arg(3); +ADD_CASES(TC_ConsoleOut, {{"^BM_no_arg_name/3 %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_no_arg_name/3\",$"}, + {"\"family_index\": 9,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_no_arg_name/3\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_no_arg_name/3\",%csv_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Arg Name Output ------------------------ // +// ========================================================================= // + +void BM_arg_name(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_arg_name)->ArgName("first")->Arg(3); +ADD_CASES(TC_ConsoleOut, {{"^BM_arg_name/first:3 %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_arg_name/first:3\",$"}, + {"\"family_index\": 10,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + 
{"\"run_name\": \"BM_arg_name/first:3\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_arg_name/first:3\",%csv_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Arg Names Output ----------------------- // +// ========================================================================= // + +void BM_arg_names(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_arg_names)->Args({2, 5, 4})->ArgNames({"first", "", "third"}); +ADD_CASES(TC_ConsoleOut, + {{"^BM_arg_names/first:2/5/third:4 %console_report$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_arg_names/first:2/5/third:4\",$"}, + {"\"family_index\": 11,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_arg_names/first:2/5/third:4\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_arg_names/first:2/5/third:4\",%csv_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Name Output ---------------------------- // +// ========================================================================= // + +void BM_name(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_name)->Name("BM_custom_name"); + +ADD_CASES(TC_ConsoleOut, {{"^BM_custom_name %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_custom_name\",$"}, + {"\"family_index\": 12,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_custom_name\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", 
MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\"$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_custom_name\",%csv_report$"}}); + +// ========================================================================= // +// ------------------------ Testing Big Args Output ------------------------ // +// ========================================================================= // + +void BM_BigArgs(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_BigArgs)->RangeMultiplier(2)->Range(1U << 30U, 1U << 31U); +ADD_CASES(TC_ConsoleOut, {{"^BM_BigArgs/1073741824 %console_report$"}, + {"^BM_BigArgs/2147483648 %console_report$"}}); + +// ========================================================================= // +// ----------------------- Testing Complexity Output ----------------------- // +// ========================================================================= // + +void BM_Complexity_O1(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + state.SetComplexityN(state.range(0)); +} +BENCHMARK(BM_Complexity_O1)->Range(1, 1 << 18)->Complexity(benchmark::o1); +SET_SUBSTITUTIONS({{"%bigOStr", "[ ]* %float \\([0-9]+\\)"}, + {"%RMS", "[ ]*[0-9]+ %"}}); +ADD_CASES(TC_ConsoleOut, {{"^BM_Complexity_O1_BigO %bigOStr %bigOStr[ ]*$"}, + {"^BM_Complexity_O1_RMS %RMS %RMS[ ]*$"}}); + +// ========================================================================= // +// ----------------------- Testing Aggregate Output ------------------------ // +// ========================================================================= // + +// Test that non-aggregate data is printed by default +void BM_Repeat(benchmark::State& state) { + for 
(auto _ : state) { + } +} +// need two repetitions min to be able to output any aggregate output +BENCHMARK(BM_Repeat)->Repetitions(2); +ADD_CASES(TC_ConsoleOut, + {{"^BM_Repeat/repeats:2 %console_report$"}, + {"^BM_Repeat/repeats:2 %console_report$"}, + {"^BM_Repeat/repeats:2_mean %console_time_only_report [ ]*2$"}, + {"^BM_Repeat/repeats:2_median %console_time_only_report [ ]*2$"}, + {"^BM_Repeat/repeats:2_stddev %console_time_only_report [ ]*2$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Repeat/repeats:2\",$"}, + {"\"family_index\": 15,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:2\"", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:2\",$"}, + {"\"family_index\": 15,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:2\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:2_mean\",$"}, + {"\"family_index\": 15,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:2_median\",$"}, + {"\"family_index\": 15,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": 
\"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:2_stddev\",$"}, + {"\"family_index\": 15,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Repeat/repeats:2\",%csv_report$"}, + {"^\"BM_Repeat/repeats:2\",%csv_report$"}, + {"^\"BM_Repeat/repeats:2_mean\",%csv_report$"}, + {"^\"BM_Repeat/repeats:2_median\",%csv_report$"}, + {"^\"BM_Repeat/repeats:2_stddev\",%csv_report$"}}); +// but for two repetitions, mean and median is the same, so let's repeat.. +BENCHMARK(BM_Repeat)->Repetitions(3); +ADD_CASES(TC_ConsoleOut, + {{"^BM_Repeat/repeats:3 %console_report$"}, + {"^BM_Repeat/repeats:3 %console_report$"}, + {"^BM_Repeat/repeats:3 %console_report$"}, + {"^BM_Repeat/repeats:3_mean %console_time_only_report [ ]*3$"}, + {"^BM_Repeat/repeats:3_median %console_time_only_report [ ]*3$"}, + {"^BM_Repeat/repeats:3_stddev %console_time_only_report [ ]*3$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Repeat/repeats:3\",$"}, + {"\"family_index\": 16,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:3\",$"}, + {"\"family_index\": 16,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": 
\"BM_Repeat/repeats:3\",$"}, + {"\"family_index\": 16,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:3_mean\",$"}, + {"\"family_index\": 16,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:3_median\",$"}, + {"\"family_index\": 16,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:3_stddev\",$"}, + {"\"family_index\": 16,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Repeat/repeats:3\",%csv_report$"}, + {"^\"BM_Repeat/repeats:3\",%csv_report$"}, + {"^\"BM_Repeat/repeats:3\",%csv_report$"}, + {"^\"BM_Repeat/repeats:3_mean\",%csv_report$"}, + {"^\"BM_Repeat/repeats:3_median\",%csv_report$"}, + {"^\"BM_Repeat/repeats:3_stddev\",%csv_report$"}}); +// median differs 
between even/odd number of repetitions, so just to be sure +BENCHMARK(BM_Repeat)->Repetitions(4); +ADD_CASES(TC_ConsoleOut, + {{"^BM_Repeat/repeats:4 %console_report$"}, + {"^BM_Repeat/repeats:4 %console_report$"}, + {"^BM_Repeat/repeats:4 %console_report$"}, + {"^BM_Repeat/repeats:4 %console_report$"}, + {"^BM_Repeat/repeats:4_mean %console_time_only_report [ ]*4$"}, + {"^BM_Repeat/repeats:4_median %console_time_only_report [ ]*4$"}, + {"^BM_Repeat/repeats:4_stddev %console_time_only_report [ ]*4$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Repeat/repeats:4\",$"}, + {"\"family_index\": 17,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:4\",$"}, + {"\"family_index\": 17,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:4\",$"}, + {"\"family_index\": 17,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"repetition_index\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:4\",$"}, + {"\"family_index\": 17,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"repetition_index\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:4_mean\",$"}, + {"\"family_index\": 17,$", 
MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 4,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:4_median\",$"}, + {"\"family_index\": 17,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 4,$", MR_Next}, + {"\"name\": \"BM_Repeat/repeats:4_stddev\",$"}, + {"\"family_index\": 17,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Repeat/repeats:4\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 4,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 4,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Repeat/repeats:4\",%csv_report$"}, + {"^\"BM_Repeat/repeats:4\",%csv_report$"}, + {"^\"BM_Repeat/repeats:4\",%csv_report$"}, + {"^\"BM_Repeat/repeats:4\",%csv_report$"}, + {"^\"BM_Repeat/repeats:4_mean\",%csv_report$"}, + {"^\"BM_Repeat/repeats:4_median\",%csv_report$"}, + {"^\"BM_Repeat/repeats:4_stddev\",%csv_report$"}}); + +// Test that a non-repeated test still prints non-aggregate results even when +// only-aggregate reports have been requested +void BM_RepeatOnce(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_RepeatOnce)->Repetitions(1)->ReportAggregatesOnly(); +ADD_CASES(TC_ConsoleOut, {{"^BM_RepeatOnce/repeats:1 %console_report$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": 
\"BM_RepeatOnce/repeats:1\",$"}, + {"\"family_index\": 18,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_RepeatOnce/repeats:1\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_RepeatOnce/repeats:1\",%csv_report$"}}); + +// Test that non-aggregate data is not reported +void BM_SummaryRepeat(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_SummaryRepeat)->Repetitions(3)->ReportAggregatesOnly(); +ADD_CASES( + TC_ConsoleOut, + {{".*BM_SummaryRepeat/repeats:3 ", MR_Not}, + {"^BM_SummaryRepeat/repeats:3_mean %console_time_only_report [ ]*3$"}, + {"^BM_SummaryRepeat/repeats:3_median %console_time_only_report [ ]*3$"}, + {"^BM_SummaryRepeat/repeats:3_stddev %console_time_only_report [ ]*3$"}}); +ADD_CASES(TC_JSONOut, + {{".*BM_SummaryRepeat/repeats:3 ", MR_Not}, + {"\"name\": \"BM_SummaryRepeat/repeats:3_mean\",$"}, + {"\"family_index\": 19,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_SummaryRepeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"name\": \"BM_SummaryRepeat/repeats:3_median\",$"}, + {"\"family_index\": 19,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_SummaryRepeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"name\": \"BM_SummaryRepeat/repeats:3_stddev\",$"}, + {"\"family_index\": 19,$", MR_Next}, + 
{"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_SummaryRepeat/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{".*BM_SummaryRepeat/repeats:3 ", MR_Not}, + {"^\"BM_SummaryRepeat/repeats:3_mean\",%csv_report$"}, + {"^\"BM_SummaryRepeat/repeats:3_median\",%csv_report$"}, + {"^\"BM_SummaryRepeat/repeats:3_stddev\",%csv_report$"}}); + +// Test that non-aggregate data is not displayed. +// NOTE: this test is kinda bad. we are only testing the display output. +// But we don't check that the file output still contains everything... +void BM_SummaryDisplay(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_SummaryDisplay)->Repetitions(2)->DisplayAggregatesOnly(); +ADD_CASES( + TC_ConsoleOut, + {{".*BM_SummaryDisplay/repeats:2 ", MR_Not}, + {"^BM_SummaryDisplay/repeats:2_mean %console_time_only_report [ ]*2$"}, + {"^BM_SummaryDisplay/repeats:2_median %console_time_only_report [ ]*2$"}, + {"^BM_SummaryDisplay/repeats:2_stddev %console_time_only_report [ ]*2$"}}); +ADD_CASES(TC_JSONOut, + {{".*BM_SummaryDisplay/repeats:2 ", MR_Not}, + {"\"name\": \"BM_SummaryDisplay/repeats:2_mean\",$"}, + {"\"family_index\": 20,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_SummaryDisplay/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"name\": \"BM_SummaryDisplay/repeats:2_median\",$"}, + {"\"family_index\": 20,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_SummaryDisplay/repeats:2\",$", 
MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"name\": \"BM_SummaryDisplay/repeats:2_stddev\",$"}, + {"\"family_index\": 20,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_SummaryDisplay/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}}); +ADD_CASES(TC_CSVOut, + {{".*BM_SummaryDisplay/repeats:2 ", MR_Not}, + {"^\"BM_SummaryDisplay/repeats:2_mean\",%csv_report$"}, + {"^\"BM_SummaryDisplay/repeats:2_median\",%csv_report$"}, + {"^\"BM_SummaryDisplay/repeats:2_stddev\",%csv_report$"}}); + +// Test repeats with custom time unit. 
+void BM_RepeatTimeUnit(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_RepeatTimeUnit) + ->Repetitions(3) + ->ReportAggregatesOnly() + ->Unit(benchmark::kMicrosecond); +ADD_CASES( + TC_ConsoleOut, + {{".*BM_RepeatTimeUnit/repeats:3 ", MR_Not}, + {"^BM_RepeatTimeUnit/repeats:3_mean %console_us_time_only_report [ ]*3$"}, + {"^BM_RepeatTimeUnit/repeats:3_median %console_us_time_only_report [ " + "]*3$"}, + {"^BM_RepeatTimeUnit/repeats:3_stddev %console_us_time_only_report [ " + "]*3$"}}); +ADD_CASES(TC_JSONOut, + {{".*BM_RepeatTimeUnit/repeats:3 ", MR_Not}, + {"\"name\": \"BM_RepeatTimeUnit/repeats:3_mean\",$"}, + {"\"family_index\": 21,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_RepeatTimeUnit/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"time_unit\": \"us\",?$"}, + {"\"name\": \"BM_RepeatTimeUnit/repeats:3_median\",$"}, + {"\"family_index\": 21,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_RepeatTimeUnit/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"time_unit\": \"us\",?$"}, + {"\"name\": \"BM_RepeatTimeUnit/repeats:3_stddev\",$"}, + {"\"family_index\": 21,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_RepeatTimeUnit/repeats:3\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + 
{"\"iterations\": 3,$", MR_Next}, + {"\"time_unit\": \"us\",?$"}}); +ADD_CASES(TC_CSVOut, + {{".*BM_RepeatTimeUnit/repeats:3 ", MR_Not}, + {"^\"BM_RepeatTimeUnit/repeats:3_mean\",%csv_us_report$"}, + {"^\"BM_RepeatTimeUnit/repeats:3_median\",%csv_us_report$"}, + {"^\"BM_RepeatTimeUnit/repeats:3_stddev\",%csv_us_report$"}}); + +// ========================================================================= // +// -------------------- Testing user-provided statistics ------------------- // +// ========================================================================= // + +const auto UserStatistics = [](const std::vector& v) { + return v.back(); +}; +void BM_UserStats(benchmark::State& state) { + for (auto _ : state) { + state.SetIterationTime(150 / 10e8); + } +} +// clang-format off +BENCHMARK(BM_UserStats) + ->Repetitions(3) + ->Iterations(5) + ->UseManualTime() + ->ComputeStatistics("", UserStatistics); +// clang-format on + +// check that user-provided stats is calculated, and is after the default-ones +// empty string as name is intentional, it would sort before anything else +ADD_CASES(TC_ConsoleOut, {{"^BM_UserStats/iterations:5/repeats:3/manual_time [ " + "]* 150 ns %time [ ]*5$"}, + {"^BM_UserStats/iterations:5/repeats:3/manual_time [ " + "]* 150 ns %time [ ]*5$"}, + {"^BM_UserStats/iterations:5/repeats:3/manual_time [ " + "]* 150 ns %time [ ]*5$"}, + {"^BM_UserStats/iterations:5/repeats:3/" + "manual_time_mean [ ]* 150 ns %time [ ]*3$"}, + {"^BM_UserStats/iterations:5/repeats:3/" + "manual_time_median [ ]* 150 ns %time [ ]*3$"}, + {"^BM_UserStats/iterations:5/repeats:3/" + "manual_time_stddev [ ]* 0.000 ns %time [ ]*3$"}, + {"^BM_UserStats/iterations:5/repeats:3/manual_time_ " + "[ ]* 150 ns %time [ ]*3$"}}); +ADD_CASES( + TC_JSONOut, + {{"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$"}, + {"\"family_index\": 22,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": 
\"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": 5,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$"}, + {"\"family_index\": 22,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": 5,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$"}, + {"\"family_index\": 22,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": 5,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time_mean\",$"}, + {"\"family_index\": 22,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time_median\",$"}, + {"\"family_index\": 22,$", MR_Next}, + 
{"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time_stddev\",$"}, + {"\"family_index\": 22,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"name\": \"BM_UserStats/iterations:5/repeats:3/manual_time_\",$"}, + {"\"family_index\": 22,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_UserStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}}); +ADD_CASES( + TC_CSVOut, + {{"^\"BM_UserStats/iterations:5/repeats:3/manual_time\",%csv_report$"}, + {"^\"BM_UserStats/iterations:5/repeats:3/manual_time\",%csv_report$"}, + {"^\"BM_UserStats/iterations:5/repeats:3/manual_time\",%csv_report$"}, + {"^\"BM_UserStats/iterations:5/repeats:3/manual_time_mean\",%csv_report$"}, + {"^\"BM_UserStats/iterations:5/repeats:3/" + "manual_time_median\",%csv_report$"}, + {"^\"BM_UserStats/iterations:5/repeats:3/" + 
"manual_time_stddev\",%csv_report$"}, + {"^\"BM_UserStats/iterations:5/repeats:3/manual_time_\",%csv_report$"}}); + +// ========================================================================= // +// ------------- Testing relative standard deviation statistics ------------ // +// ========================================================================= // + +const auto UserPercentStatistics = [](const std::vector&) { + return 1. / 100.; +}; +void BM_UserPercentStats(benchmark::State& state) { + for (auto _ : state) { + state.SetIterationTime(150 / 10e8); + } +} +// clang-format off +BENCHMARK(BM_UserPercentStats) + ->Repetitions(3) + ->Iterations(5) + ->UseManualTime() + ->Unit(benchmark::TimeUnit::kNanosecond) + ->ComputeStatistics("", UserPercentStatistics, benchmark::StatisticUnit::kPercentage); +// clang-format on + +// check that UserPercent-provided stats is calculated, and is after the +// default-ones empty string as name is intentional, it would sort before +// anything else +ADD_CASES(TC_ConsoleOut, + {{"^BM_UserPercentStats/iterations:5/repeats:3/manual_time [ " + "]* 150 ns %time [ ]*5$"}, + {"^BM_UserPercentStats/iterations:5/repeats:3/manual_time [ " + "]* 150 ns %time [ ]*5$"}, + {"^BM_UserPercentStats/iterations:5/repeats:3/manual_time [ " + "]* 150 ns %time [ ]*5$"}, + {"^BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_mean [ ]* 150 ns %time [ ]*3$"}, + {"^BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_median [ ]* 150 ns %time [ ]*3$"}, + {"^BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_stddev [ ]* 0.000 ns %time [ ]*3$"}, + {"^BM_UserPercentStats/iterations:5/repeats:3/manual_time_ " + "[ ]* 1.00 % [ ]* 1.00 %[ ]*3$"}}); +ADD_CASES( + TC_JSONOut, + {{"\"name\": \"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$"}, + {"\"family_index\": 23,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, 
+ {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": 5,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$"}, + {"\"family_index\": 23,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": 5,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": \"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$"}, + {"\"family_index\": 23,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"repetition_index\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": 5,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time_mean\",$"}, + {"\"family_index\": 23,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time_median\",$"}, + {"\"family_index\": 23,$", MR_Next}, + 
{"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": 1\\.5(0)*e\\+(0)*2,$", MR_Next}, + {"\"name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time_stddev\",$"}, + {"\"family_index\": 23,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time_\",$"}, + {"\"family_index\": 23,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": " + "\"BM_UserPercentStats/iterations:5/repeats:3/manual_time\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 3,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"\",$", MR_Next}, + {"\"aggregate_unit\": \"percentage\",$", MR_Next}, + {"\"iterations\": 3,$", MR_Next}, + {"\"real_time\": 1\\.(0)*e-(0)*2,$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time\",%csv_report$"}, + {"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time\",%csv_report$"}, + {"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time\",%csv_report$"}, + {"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_mean\",%csv_report$"}, + 
{"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_median\",%csv_report$"}, + {"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_stddev\",%csv_report$"}, + {"^\"BM_UserPercentStats/iterations:5/repeats:3/" + "manual_time_\",%csv_cv_report$"}}); + +// ========================================================================= // +// ------------------------- Testing StrEscape JSON ------------------------ // +// ========================================================================= // +#if 0 // enable when csv testing code correctly handles multi-line fields +void BM_JSON_Format(benchmark::State& state) { + state.SkipWithError("val\b\f\n\r\t\\\"with\"es,capes"); + for (auto _ : state) { + } +} +BENCHMARK(BM_JSON_Format); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_JSON_Format\",$"}, + {"\"family_index\": 23,$", MR_Next}, +{"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_JSON_Format\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"error_occurred\": true,$", MR_Next}, + {R"("error_message": "val\\b\\f\\n\\r\\t\\\\\\"with\\"es,capes",$)", MR_Next}}); +#endif +// ========================================================================= // +// -------------------------- Testing CsvEscape ---------------------------- // +// ========================================================================= // + +void BM_CSV_Format(benchmark::State& state) { + state.SkipWithError("\"freedom\""); + for (auto _ : state) { + } +} +BENCHMARK(BM_CSV_Format); +ADD_CASES(TC_CSVOut, {{"^\"BM_CSV_Format\",,,,,,,,true,\"\"\"freedom\"\"\"$"}}); + +// ========================================================================= // +// --------------------------- TEST CASES END ------------------------------ // +// ========================================================================= // + +int main(int argc, char* 
argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/test/skip_with_error_test.cc b/third_party/benchmark/test/skip_with_error_test.cc new file mode 100644 index 0000000..2139a19 --- /dev/null +++ b/third_party/benchmark/test/skip_with_error_test.cc @@ -0,0 +1,199 @@ + +#undef NDEBUG +#include +#include + +#include "../src/check.h" // NOTE: check.h is for internal use only! +#include "benchmark/benchmark.h" + +namespace { + +class TestReporter : public benchmark::ConsoleReporter { + public: + bool ReportContext(const Context& context) override { + return ConsoleReporter::ReportContext(context); + }; + + void ReportRuns(const std::vector& report) override { + all_runs_.insert(all_runs_.end(), begin(report), end(report)); + ConsoleReporter::ReportRuns(report); + } + + TestReporter() {} + ~TestReporter() override {} + + mutable std::vector all_runs_; +}; + +struct TestCase { + std::string name; + bool error_occurred; + std::string error_message; + + typedef benchmark::BenchmarkReporter::Run Run; + + void CheckRun(Run const& run) const { + BM_CHECK(name == run.benchmark_name()) + << "expected " << name << " got " << run.benchmark_name(); + BM_CHECK_EQ(error_occurred, + benchmark::internal::SkippedWithError == run.skipped); + BM_CHECK(error_message == run.skip_message); + if (error_occurred) { + // BM_CHECK(run.iterations == 0); + } else { + BM_CHECK(run.iterations != 0); + } + } +}; + +std::vector ExpectedResults; + +int AddCases(const std::string& base_name, + std::initializer_list const& v) { + for (auto TC : v) { + TC.name = base_name + TC.name; + ExpectedResults.push_back(std::move(TC)); + } + return 0; +} + +#define CONCAT(x, y) CONCAT2(x, y) +#define CONCAT2(x, y) x##y +#define ADD_CASES(...) 
int CONCAT(dummy, __LINE__) = AddCases(__VA_ARGS__) + +} // end namespace + +void BM_error_no_running(benchmark::State& state) { + state.SkipWithError("error message"); +} +BENCHMARK(BM_error_no_running); +ADD_CASES("BM_error_no_running", {{"", true, "error message"}}); + +void BM_error_before_running(benchmark::State& state) { + state.SkipWithError("error message"); + while (state.KeepRunning()) { + assert(false); + } +} +BENCHMARK(BM_error_before_running); +ADD_CASES("BM_error_before_running", {{"", true, "error message"}}); + +void BM_error_before_running_batch(benchmark::State& state) { + state.SkipWithError("error message"); + while (state.KeepRunningBatch(17)) { + assert(false); + } +} +BENCHMARK(BM_error_before_running_batch); +ADD_CASES("BM_error_before_running_batch", {{"", true, "error message"}}); + +void BM_error_before_running_range_for(benchmark::State& state) { + state.SkipWithError("error message"); + for (auto _ : state) { + assert(false); + } +} +BENCHMARK(BM_error_before_running_range_for); +ADD_CASES("BM_error_before_running_range_for", {{"", true, "error message"}}); + +void BM_error_during_running(benchmark::State& state) { + int first_iter = true; + while (state.KeepRunning()) { + if (state.range(0) == 1 && state.thread_index() <= (state.threads() / 2)) { + assert(first_iter); + first_iter = false; + state.SkipWithError("error message"); + } else { + state.PauseTiming(); + state.ResumeTiming(); + } + } +} +BENCHMARK(BM_error_during_running)->Arg(1)->Arg(2)->ThreadRange(1, 8); +ADD_CASES("BM_error_during_running", {{"/1/threads:1", true, "error message"}, + {"/1/threads:2", true, "error message"}, + {"/1/threads:4", true, "error message"}, + {"/1/threads:8", true, "error message"}, + {"/2/threads:1", false, ""}, + {"/2/threads:2", false, ""}, + {"/2/threads:4", false, ""}, + {"/2/threads:8", false, ""}}); + +void BM_error_during_running_ranged_for(benchmark::State& state) { + assert(state.max_iterations > 3 && "test requires at least a few 
iterations"); + bool first_iter = true; + // NOTE: Users should not write the for loop explicitly. + for (auto It = state.begin(), End = state.end(); It != End; ++It) { + if (state.range(0) == 1) { + assert(first_iter); + first_iter = false; + (void)first_iter; + state.SkipWithError("error message"); + // Test the unfortunate but documented behavior that the ranged-for loop + // doesn't automatically terminate when SkipWithError is set. + assert(++It != End); + break; // Required behavior + } + } +} +BENCHMARK(BM_error_during_running_ranged_for)->Arg(1)->Arg(2)->Iterations(5); +ADD_CASES("BM_error_during_running_ranged_for", + {{"/1/iterations:5", true, "error message"}, + {"/2/iterations:5", false, ""}}); + +void BM_error_after_running(benchmark::State& state) { + for (auto _ : state) { + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + if (state.thread_index() <= (state.threads() / 2)) + state.SkipWithError("error message"); +} +BENCHMARK(BM_error_after_running)->ThreadRange(1, 8); +ADD_CASES("BM_error_after_running", {{"/threads:1", true, "error message"}, + {"/threads:2", true, "error message"}, + {"/threads:4", true, "error message"}, + {"/threads:8", true, "error message"}}); + +void BM_error_while_paused(benchmark::State& state) { + bool first_iter = true; + while (state.KeepRunning()) { + if (state.range(0) == 1 && state.thread_index() <= (state.threads() / 2)) { + assert(first_iter); + first_iter = false; + state.PauseTiming(); + state.SkipWithError("error message"); + } else { + state.PauseTiming(); + state.ResumeTiming(); + } + } +} +BENCHMARK(BM_error_while_paused)->Arg(1)->Arg(2)->ThreadRange(1, 8); +ADD_CASES("BM_error_while_paused", {{"/1/threads:1", true, "error message"}, + {"/1/threads:2", true, "error message"}, + {"/1/threads:4", true, "error message"}, + {"/1/threads:8", true, "error message"}, + {"/2/threads:1", false, ""}, + {"/2/threads:2", false, ""}, + {"/2/threads:4", 
false, ""}, + {"/2/threads:8", false, ""}}); + +int main(int argc, char* argv[]) { + benchmark::Initialize(&argc, argv); + + TestReporter test_reporter; + benchmark::RunSpecifiedBenchmarks(&test_reporter); + + typedef benchmark::BenchmarkReporter::Run Run; + auto EB = ExpectedResults.begin(); + + for (Run const& run : test_reporter.all_runs_) { + assert(EB != ExpectedResults.end()); + EB->CheckRun(run); + ++EB; + } + assert(EB == ExpectedResults.end()); + + return 0; +} diff --git a/third_party/benchmark/test/spec_arg_test.cc b/third_party/benchmark/test/spec_arg_test.cc new file mode 100644 index 0000000..06aafbe --- /dev/null +++ b/third_party/benchmark/test/spec_arg_test.cc @@ -0,0 +1,105 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark/benchmark.h" + +// Tests that we can override benchmark-spec value from FLAGS_benchmark_filter +// with argument to RunSpecifiedBenchmarks(...). + +namespace { + +class TestReporter : public benchmark::ConsoleReporter { + public: + bool ReportContext(const Context& context) override { + return ConsoleReporter::ReportContext(context); + }; + + void ReportRuns(const std::vector& report) override { + assert(report.size() == 1); + matched_functions.push_back(report[0].run_name.function_name); + ConsoleReporter::ReportRuns(report); + }; + + TestReporter() {} + + ~TestReporter() override {} + + const std::vector& GetMatchedFunctions() const { + return matched_functions; + } + + private: + std::vector matched_functions; +}; + +} // end namespace + +static void BM_NotChosen(benchmark::State& state) { + assert(false && "SHOULD NOT BE CALLED"); + for (auto _ : state) { + } +} +BENCHMARK(BM_NotChosen); + +static void BM_Chosen(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_Chosen); + +int main(int argc, char** argv) { + const std::string flag = "BM_NotChosen"; + + // Verify that argv specify --benchmark_filter=BM_NotChosen. 
+ bool found = false; + for (int i = 0; i < argc; ++i) { + if (strcmp("--benchmark_filter=BM_NotChosen", argv[i]) == 0) { + found = true; + break; + } + } + assert(found); + + benchmark::Initialize(&argc, argv); + + // Check that the current flag value is reported accurately via the + // GetBenchmarkFilter() function. + if (flag != benchmark::GetBenchmarkFilter()) { + std::cerr + << "Seeing different value for flags. GetBenchmarkFilter() returns [" + << benchmark::GetBenchmarkFilter() << "] expected flag=[" << flag + << "]\n"; + return 1; + } + TestReporter test_reporter; + const char* const spec = "BM_Chosen"; + const size_t returned_count = + benchmark::RunSpecifiedBenchmarks(&test_reporter, spec); + assert(returned_count == 1); + const std::vector matched_functions = + test_reporter.GetMatchedFunctions(); + assert(matched_functions.size() == 1); + if (strcmp(spec, matched_functions.front().c_str()) != 0) { + std::cerr << "Expected benchmark [" << spec << "] to run, but got [" + << matched_functions.front() << "]\n"; + return 2; + } + + // Test that SetBenchmarkFilter works. + const std::string golden_value = "golden_value"; + benchmark::SetBenchmarkFilter(golden_value); + std::string current_value = benchmark::GetBenchmarkFilter(); + if (golden_value != current_value) { + std::cerr << "Expected [" << golden_value + << "] for --benchmark_filter but got [" << current_value << "]\n"; + return 3; + } + return 0; +} diff --git a/third_party/benchmark/test/spec_arg_verbosity_test.cc b/third_party/benchmark/test/spec_arg_verbosity_test.cc new file mode 100644 index 0000000..8f8eb6d --- /dev/null +++ b/third_party/benchmark/test/spec_arg_verbosity_test.cc @@ -0,0 +1,43 @@ +#include + +#include + +#include "benchmark/benchmark.h" + +// Tests that the user specified verbosity level can be get. 
+static void BM_Verbosity(benchmark::State& state) { + for (auto _ : state) { + } +} +BENCHMARK(BM_Verbosity); + +int main(int argc, char** argv) { + const int32_t flagv = 42; + + // Verify that argv specify --v=42. + bool found = false; + for (int i = 0; i < argc; ++i) { + if (strcmp("--v=42", argv[i]) == 0) { + found = true; + break; + } + } + if (!found) { + std::cerr << "This test requires '--v=42' to be passed as a command-line " + << "argument.\n"; + return 1; + } + + benchmark::Initialize(&argc, argv); + + // Check that the current flag value is reported accurately via the + // GetBenchmarkVerbosity() function. + if (flagv != benchmark::GetBenchmarkVerbosity()) { + std::cerr + << "Seeing different value for flags. GetBenchmarkVerbosity() returns [" + << benchmark::GetBenchmarkVerbosity() << "] expected flag=[" << flagv + << "]\n"; + return 1; + } + return 0; +} diff --git a/third_party/benchmark/test/state_assembly_test.cc b/third_party/benchmark/test/state_assembly_test.cc new file mode 100644 index 0000000..7ddbb3b --- /dev/null +++ b/third_party/benchmark/test/state_assembly_test.cc @@ -0,0 +1,68 @@ +#include + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wreturn-type" +#endif + +// clang-format off +extern "C" { + extern int ExternInt; + benchmark::State& GetState(); + void Fn(); +} +// clang-format on + +using benchmark::State; + +// CHECK-LABEL: test_for_auto_loop: +extern "C" int test_for_auto_loop() { + State& S = GetState(); + int x = 42; + // CHECK: [[CALL:call(q)*]] _ZN9benchmark5State16StartKeepRunningEv + // CHECK-NEXT: testq %rbx, %rbx + // CHECK-NEXT: je [[LOOP_END:.*]] + + for (auto _ : S) { + // CHECK: .L[[LOOP_HEAD:[a-zA-Z0-9_]+]]: + // CHECK-GNU-NEXT: subq $1, %rbx + // CHECK-CLANG-NEXT: {{(addq \$1, %rax|incq %rax|addq \$-1, %rbx)}} + // CHECK-NEXT: jne .L[[LOOP_HEAD]] + benchmark::DoNotOptimize(x); + } + // CHECK: [[LOOP_END]]: + // CHECK: [[CALL]] _ZN9benchmark5State17FinishKeepRunningEv + + // CHECK: movl $101, %eax + // 
CHECK: ret + return 101; +} + +// CHECK-LABEL: test_while_loop: +extern "C" int test_while_loop() { + State& S = GetState(); + int x = 42; + + // CHECK: j{{(e|mp)}} .L[[LOOP_HEADER:[a-zA-Z0-9_]+]] + // CHECK-NEXT: .L[[LOOP_BODY:[a-zA-Z0-9_]+]]: + while (S.KeepRunning()) { + // CHECK-GNU-NEXT: subq $1, %[[IREG:[a-z]+]] + // CHECK-CLANG-NEXT: {{(addq \$-1,|decq)}} %[[IREG:[a-z]+]] + // CHECK: movq %[[IREG]], [[DEST:.*]] + benchmark::DoNotOptimize(x); + } + // CHECK-DAG: movq [[DEST]], %[[IREG]] + // CHECK-DAG: testq %[[IREG]], %[[IREG]] + // CHECK-DAG: jne .L[[LOOP_BODY]] + // CHECK-DAG: .L[[LOOP_HEADER]]: + + // CHECK: cmpb $0 + // CHECK-NEXT: jne .L[[LOOP_END:[a-zA-Z0-9_]+]] + // CHECK: [[CALL:call(q)*]] _ZN9benchmark5State16StartKeepRunningEv + + // CHECK: .L[[LOOP_END]]: + // CHECK: [[CALL]] _ZN9benchmark5State17FinishKeepRunningEv + + // CHECK: movl $101, %eax + // CHECK: ret + return 101; +} diff --git a/third_party/benchmark/test/statistics_gtest.cc b/third_party/benchmark/test/statistics_gtest.cc new file mode 100644 index 0000000..48c7726 --- /dev/null +++ b/third_party/benchmark/test/statistics_gtest.cc @@ -0,0 +1,35 @@ +//===---------------------------------------------------------------------===// +// statistics_test - Unit tests for src/statistics.cc +//===---------------------------------------------------------------------===// + +#include "../src/statistics.h" +#include "gtest/gtest.h" + +namespace { +TEST(StatisticsTest, Mean) { + EXPECT_DOUBLE_EQ(benchmark::StatisticsMean({42, 42, 42, 42}), 42.0); + EXPECT_DOUBLE_EQ(benchmark::StatisticsMean({1, 2, 3, 4}), 2.5); + EXPECT_DOUBLE_EQ(benchmark::StatisticsMean({1, 2, 5, 10, 10, 14}), 7.0); +} + +TEST(StatisticsTest, Median) { + EXPECT_DOUBLE_EQ(benchmark::StatisticsMedian({42, 42, 42, 42}), 42.0); + EXPECT_DOUBLE_EQ(benchmark::StatisticsMedian({1, 2, 3, 4}), 2.5); + EXPECT_DOUBLE_EQ(benchmark::StatisticsMedian({1, 2, 5, 10, 10}), 5.0); +} + +TEST(StatisticsTest, StdDev) { + 
EXPECT_DOUBLE_EQ(benchmark::StatisticsStdDev({101, 101, 101, 101}), 0.0); + EXPECT_DOUBLE_EQ(benchmark::StatisticsStdDev({1, 2, 3}), 1.0); + EXPECT_DOUBLE_EQ(benchmark::StatisticsStdDev({2.5, 2.4, 3.3, 4.2, 5.1}), + 1.151086443322134); +} + +TEST(StatisticsTest, CV) { + EXPECT_DOUBLE_EQ(benchmark::StatisticsCV({101, 101, 101, 101}), 0.0); + EXPECT_DOUBLE_EQ(benchmark::StatisticsCV({1, 2, 3}), 1. / 2.); + ASSERT_NEAR(benchmark::StatisticsCV({2.5, 2.4, 3.3, 4.2, 5.1}), + 0.32888184094918121, 1e-15); +} + +} // end namespace diff --git a/third_party/benchmark/test/string_util_gtest.cc b/third_party/benchmark/test/string_util_gtest.cc new file mode 100644 index 0000000..67b4bc0 --- /dev/null +++ b/third_party/benchmark/test/string_util_gtest.cc @@ -0,0 +1,199 @@ +//===---------------------------------------------------------------------===// +// string_util_test - Unit tests for src/string_util.cc +//===---------------------------------------------------------------------===// + +#include + +#include "../src/internal_macros.h" +#include "../src/string_util.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace { +TEST(StringUtilTest, stoul) { + { + size_t pos = 0; + EXPECT_EQ(0ul, benchmark::stoul("0", &pos)); + EXPECT_EQ(1ul, pos); + } + { + size_t pos = 0; + EXPECT_EQ(7ul, benchmark::stoul("7", &pos)); + EXPECT_EQ(1ul, pos); + } + { + size_t pos = 0; + EXPECT_EQ(135ul, benchmark::stoul("135", &pos)); + EXPECT_EQ(3ul, pos); + } +#if ULONG_MAX == 0xFFFFFFFFul + { + size_t pos = 0; + EXPECT_EQ(0xFFFFFFFFul, benchmark::stoul("4294967295", &pos)); + EXPECT_EQ(10ul, pos); + } +#elif ULONG_MAX == 0xFFFFFFFFFFFFFFFFul + { + size_t pos = 0; + EXPECT_EQ(0xFFFFFFFFFFFFFFFFul, + benchmark::stoul("18446744073709551615", &pos)); + EXPECT_EQ(20ul, pos); + } +#endif + { + size_t pos = 0; + EXPECT_EQ(10ul, benchmark::stoul("1010", &pos, 2)); + EXPECT_EQ(4ul, pos); + } + { + size_t pos = 0; + EXPECT_EQ(520ul, benchmark::stoul("1010", &pos, 8)); + EXPECT_EQ(4ul, pos); + } 
+ { + size_t pos = 0; + EXPECT_EQ(1010ul, benchmark::stoul("1010", &pos, 10)); + EXPECT_EQ(4ul, pos); + } + { + size_t pos = 0; + EXPECT_EQ(4112ul, benchmark::stoul("1010", &pos, 16)); + EXPECT_EQ(4ul, pos); + } + { + size_t pos = 0; + EXPECT_EQ(0xBEEFul, benchmark::stoul("BEEF", &pos, 16)); + EXPECT_EQ(4ul, pos); + } +#ifndef BENCHMARK_HAS_NO_EXCEPTIONS + { + ASSERT_THROW(std::ignore = benchmark::stoul("this is a test"), + std::invalid_argument); + } +#endif +} + +TEST(StringUtilTest, stoi){{size_t pos = 0; +EXPECT_EQ(0, benchmark::stoi("0", &pos)); +EXPECT_EQ(1ul, pos); +} // namespace +{ + size_t pos = 0; + EXPECT_EQ(-17, benchmark::stoi("-17", &pos)); + EXPECT_EQ(3ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(1357, benchmark::stoi("1357", &pos)); + EXPECT_EQ(4ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(10, benchmark::stoi("1010", &pos, 2)); + EXPECT_EQ(4ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(520, benchmark::stoi("1010", &pos, 8)); + EXPECT_EQ(4ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(1010, benchmark::stoi("1010", &pos, 10)); + EXPECT_EQ(4ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(4112, benchmark::stoi("1010", &pos, 16)); + EXPECT_EQ(4ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(0xBEEF, benchmark::stoi("BEEF", &pos, 16)); + EXPECT_EQ(4ul, pos); +} +#ifndef BENCHMARK_HAS_NO_EXCEPTIONS +{ + ASSERT_THROW(std::ignore = benchmark::stoi("this is a test"), + std::invalid_argument); +} +#endif +} + +TEST(StringUtilTest, stod){{size_t pos = 0; +EXPECT_EQ(0.0, benchmark::stod("0", &pos)); +EXPECT_EQ(1ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(-84.0, benchmark::stod("-84", &pos)); + EXPECT_EQ(3ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(1234.0, benchmark::stod("1234", &pos)); + EXPECT_EQ(4ul, pos); +} +{ + size_t pos = 0; + EXPECT_EQ(1.5, benchmark::stod("1.5", &pos)); + EXPECT_EQ(3ul, pos); +} +{ + size_t pos = 0; + /* Note: exactly representable as double */ + EXPECT_EQ(-1.25e+9, benchmark::stod("-1.25e+9", &pos)); + EXPECT_EQ(8ul, pos); +} 
+#ifndef BENCHMARK_HAS_NO_EXCEPTIONS +{ + ASSERT_THROW(std::ignore = benchmark::stod("this is a test"), + std::invalid_argument); +} +#endif +} + +TEST(StringUtilTest, StrSplit) { + EXPECT_EQ(benchmark::StrSplit("", ','), std::vector{}); + EXPECT_EQ(benchmark::StrSplit("hello", ','), + std::vector({"hello"})); + EXPECT_EQ(benchmark::StrSplit("hello,there,is,more", ','), + std::vector({"hello", "there", "is", "more"})); +} + +using HumanReadableFixture = ::testing::TestWithParam< + std::tuple>; + +INSTANTIATE_TEST_SUITE_P( + HumanReadableTests, HumanReadableFixture, + ::testing::Values( + std::make_tuple(0.0, benchmark::Counter::kIs1024, "0"), + std::make_tuple(999.0, benchmark::Counter::kIs1024, "999"), + std::make_tuple(1000.0, benchmark::Counter::kIs1024, "1000"), + std::make_tuple(1024.0, benchmark::Counter::kIs1024, "1Ki"), + std::make_tuple(1000 * 1000.0, benchmark::Counter::kIs1024, + "976\\.56.Ki"), + std::make_tuple(1024 * 1024.0, benchmark::Counter::kIs1024, "1Mi"), + std::make_tuple(1000 * 1000 * 1000.0, benchmark::Counter::kIs1024, + "953\\.674Mi"), + std::make_tuple(1024 * 1024 * 1024.0, benchmark::Counter::kIs1024, + "1Gi"), + std::make_tuple(0.0, benchmark::Counter::kIs1000, "0"), + std::make_tuple(999.0, benchmark::Counter::kIs1000, "999"), + std::make_tuple(1000.0, benchmark::Counter::kIs1000, "1k"), + std::make_tuple(1024.0, benchmark::Counter::kIs1000, "1.024k"), + std::make_tuple(1000 * 1000.0, benchmark::Counter::kIs1000, "1M"), + std::make_tuple(1024 * 1024.0, benchmark::Counter::kIs1000, + "1\\.04858M"), + std::make_tuple(1000 * 1000 * 1000.0, benchmark::Counter::kIs1000, + "1G"), + std::make_tuple(1024 * 1024 * 1024.0, benchmark::Counter::kIs1000, + "1\\.07374G"))); + +TEST_P(HumanReadableFixture, HumanReadableNumber) { + std::string str = benchmark::HumanReadableNumber(std::get<0>(GetParam()), + std::get<1>(GetParam())); + ASSERT_THAT(str, ::testing::MatchesRegex(std::get<2>(GetParam()))); +} + +} // end namespace diff --git 
a/third_party/benchmark/test/templated_fixture_test.cc b/third_party/benchmark/test/templated_fixture_test.cc new file mode 100644 index 0000000..af239c3 --- /dev/null +++ b/third_party/benchmark/test/templated_fixture_test.cc @@ -0,0 +1,28 @@ + +#include +#include + +#include "benchmark/benchmark.h" + +template +class MyFixture : public ::benchmark::Fixture { + public: + MyFixture() : data(0) {} + + T data; +}; + +BENCHMARK_TEMPLATE_F(MyFixture, Foo, int)(benchmark::State& st) { + for (auto _ : st) { + data += 1; + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(MyFixture, Bar, double)(benchmark::State& st) { + for (auto _ : st) { + data += 1.0; + } +} +BENCHMARK_REGISTER_F(MyFixture, Bar); + +BENCHMARK_MAIN(); diff --git a/third_party/benchmark/test/time_unit_gtest.cc b/third_party/benchmark/test/time_unit_gtest.cc new file mode 100644 index 0000000..484ecbc --- /dev/null +++ b/third_party/benchmark/test/time_unit_gtest.cc @@ -0,0 +1,37 @@ +#include "../include/benchmark/benchmark.h" +#include "gtest/gtest.h" + +namespace benchmark { +namespace internal { + +namespace { + +class DummyBenchmark : public Benchmark { + public: + DummyBenchmark() : Benchmark("dummy") {} + void Run(State&) override {} +}; + +TEST(DefaultTimeUnitTest, TimeUnitIsNotSet) { + DummyBenchmark benchmark; + EXPECT_EQ(benchmark.GetTimeUnit(), kNanosecond); +} + +TEST(DefaultTimeUnitTest, DefaultIsSet) { + DummyBenchmark benchmark; + EXPECT_EQ(benchmark.GetTimeUnit(), kNanosecond); + SetDefaultTimeUnit(kMillisecond); + EXPECT_EQ(benchmark.GetTimeUnit(), kMillisecond); +} + +TEST(DefaultTimeUnitTest, DefaultAndExplicitUnitIsSet) { + DummyBenchmark benchmark; + benchmark.Unit(kMillisecond); + SetDefaultTimeUnit(kMicrosecond); + + EXPECT_EQ(benchmark.GetTimeUnit(), kMillisecond); +} + +} // namespace +} // namespace internal +} // namespace benchmark diff --git a/third_party/benchmark/test/user_counters_tabular_test.cc b/third_party/benchmark/test/user_counters_tabular_test.cc new file mode 100644 index 
0000000..cfc1ab0 --- /dev/null +++ b/third_party/benchmark/test/user_counters_tabular_test.cc @@ -0,0 +1,561 @@ + +#undef NDEBUG + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// @todo: this checks the full output at once; the rule for +// CounterSet1 was failing because it was not matching "^[-]+$". +// @todo: check that the counters are vertically aligned. +ADD_CASES(TC_ConsoleOut, + { + // keeping these lines long improves readability, so: + // clang-format off + {"^[-]+$", MR_Next}, + {"^Benchmark %s Time %s CPU %s Iterations %s Bar %s Bat %s Baz %s Foo %s Frob %s Lob$", MR_Next}, + {"^[-]+$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:1 %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:1 %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:1_mean %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:1_median %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:1_stddev %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:1_cv %console_percentage_report [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*%$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:2 %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:2 %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:2_mean %console_report [ 
]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:2_median %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:2_stddev %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_Counters_Tabular/repeats:2/threads:2_cv %console_percentage_report [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*% [ ]*%percentage[ ]*%$", MR_Next}, + {"^BM_CounterRates_Tabular/threads:%int %console_report [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s$", MR_Next}, + {"^BM_CounterRates_Tabular/threads:%int %console_report [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s$", MR_Next}, + {"^BM_CounterRates_Tabular/threads:%int %console_report [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s$", MR_Next}, + {"^BM_CounterRates_Tabular/threads:%int %console_report [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s$", MR_Next}, + {"^BM_CounterRates_Tabular/threads:%int %console_report [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s [ ]*%hrfloat/s$", MR_Next}, + {"^[-]+$", MR_Next}, + {"^Benchmark %s Time %s CPU %s Iterations %s Bar %s Baz %s Foo$", MR_Next}, + {"^[-]+$", MR_Next}, + {"^BM_CounterSet0_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet0_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet0_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet0_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + 
{"^BM_CounterSet0_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet1_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet1_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet1_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet1_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet1_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^[-]+$", MR_Next}, + {"^Benchmark %s Time %s CPU %s Iterations %s Bat %s Baz %s Foo$", MR_Next}, + {"^[-]+$", MR_Next}, + {"^BM_CounterSet2_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet2_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet2_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet2_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$", MR_Next}, + {"^BM_CounterSet2_Tabular/threads:%int %console_report [ ]*%hrfloat [ ]*%hrfloat [ ]*%hrfloat$"}, + // clang-format on + }); +ADD_CASES(TC_CSVOut, {{"%csv_header," + "\"Bar\",\"Bat\",\"Baz\",\"Foo\",\"Frob\",\"Lob\""}}); + +// ========================================================================= // +// ------------------------- Tabular Counters Output ----------------------- // +// ========================================================================= // + +void BM_Counters_Tabular(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + 
state.counters.insert({ + {"Foo", {1, bm::Counter::kAvgThreads}}, + {"Bar", {2, bm::Counter::kAvgThreads}}, + {"Baz", {4, bm::Counter::kAvgThreads}}, + {"Bat", {8, bm::Counter::kAvgThreads}}, + {"Frob", {16, bm::Counter::kAvgThreads}}, + {"Lob", {32, bm::Counter::kAvgThreads}}, + }); +} +BENCHMARK(BM_Counters_Tabular)->ThreadRange(1, 2)->Repetitions(2); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": 
\"BM_Counters_Tabular/repeats:2/threads:1_mean\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:1_median\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:1_stddev\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 
2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:1_cv\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:1\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"cv\",$", MR_Next}, + {"\"aggregate_unit\": \"percentage\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); + +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 1,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 2,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": 
%float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 1,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 2,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:2_median\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 1,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 2,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:2_stddev\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 1,$", 
MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 2,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Tabular/repeats:2/threads:2_cv\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 1,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Tabular/repeats:2/threads:2\",$", + MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 2,$", MR_Next}, + {"\"aggregate_name\": \"cv\",$", MR_Next}, + {"\"aggregate_unit\": \"percentage\",$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:1\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:1\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:1_mean\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + 
{{"^\"BM_Counters_Tabular/repeats:2/threads:1_median\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:1_stddev\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:1_cv\",%csv_cv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:2\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:2\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:2_mean\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:2_median\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:2_stddev\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_Tabular/repeats:2/threads:2_cv\",%csv_cv_report," + "%float,%float,%float,%float,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckTabular(Results const& e) { + CHECK_COUNTER_VALUE(e, int, "Foo", EQ, 1); + CHECK_COUNTER_VALUE(e, int, "Bar", EQ, 2); + CHECK_COUNTER_VALUE(e, int, "Baz", EQ, 4); + CHECK_COUNTER_VALUE(e, int, "Bat", EQ, 8); + CHECK_COUNTER_VALUE(e, int, "Frob", EQ, 16); + CHECK_COUNTER_VALUE(e, int, "Lob", EQ, 32); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_Tabular/repeats:2/threads:1$", + &CheckTabular); +CHECK_BENCHMARK_RESULTS("BM_Counters_Tabular/repeats:2/threads:2$", + &CheckTabular); + +// ========================================================================= // +// -------------------- Tabular+Rate Counters Output 
----------------------- // +// ========================================================================= // + +void BM_CounterRates_Tabular(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters.insert({ + {"Foo", {1, bm::Counter::kAvgThreadsRate}}, + {"Bar", {2, bm::Counter::kAvgThreadsRate}}, + {"Baz", {4, bm::Counter::kAvgThreadsRate}}, + {"Bat", {8, bm::Counter::kAvgThreadsRate}}, + {"Frob", {16, bm::Counter::kAvgThreadsRate}}, + {"Lob", {32, bm::Counter::kAvgThreadsRate}}, + }); +} +BENCHMARK(BM_CounterRates_Tabular)->ThreadRange(1, 16); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_CounterRates_Tabular/threads:%int\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_CounterRates_Tabular/threads:%int\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float,$", MR_Next}, + {"\"Frob\": %float,$", MR_Next}, + {"\"Lob\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_CounterRates_Tabular/threads:%int\",%csv_report," + "%float,%float,%float,%float,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckTabularRate(Results const& e) { + double t = e.DurationCPUTime(); + CHECK_FLOAT_COUNTER_VALUE(e, "Foo", EQ, 1. / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "Bar", EQ, 2. 
/ t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "Baz", EQ, 4. / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "Bat", EQ, 8. / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "Frob", EQ, 16. / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "Lob", EQ, 32. / t, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_CounterRates_Tabular/threads:%int", + &CheckTabularRate); + +// ========================================================================= // +// ------------------------- Tabular Counters Output ----------------------- // +// ========================================================================= // + +// set only some of the counters +void BM_CounterSet0_Tabular(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters.insert({ + {"Foo", {10, bm::Counter::kAvgThreads}}, + {"Bar", {20, bm::Counter::kAvgThreads}}, + {"Baz", {40, bm::Counter::kAvgThreads}}, + }); +} +BENCHMARK(BM_CounterSet0_Tabular)->ThreadRange(1, 16); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_CounterSet0_Tabular/threads:%int\",$"}, + {"\"family_index\": 2,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_CounterSet0_Tabular/threads:%int\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_CounterSet0_Tabular/threads:%int\",%csv_report," + "%float,,%float,%float,,"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckSet0(Results const& e) { + CHECK_COUNTER_VALUE(e, int, "Foo", EQ, 10); + CHECK_COUNTER_VALUE(e, int, "Bar", EQ, 20); + 
CHECK_COUNTER_VALUE(e, int, "Baz", EQ, 40); +} +CHECK_BENCHMARK_RESULTS("BM_CounterSet0_Tabular", &CheckSet0); + +// again. +void BM_CounterSet1_Tabular(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters.insert({ + {"Foo", {15, bm::Counter::kAvgThreads}}, + {"Bar", {25, bm::Counter::kAvgThreads}}, + {"Baz", {45, bm::Counter::kAvgThreads}}, + }); +} +BENCHMARK(BM_CounterSet1_Tabular)->ThreadRange(1, 16); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_CounterSet1_Tabular/threads:%int\",$"}, + {"\"family_index\": 3,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_CounterSet1_Tabular/threads:%int\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bar\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_CounterSet1_Tabular/threads:%int\",%csv_report," + "%float,,%float,%float,,"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckSet1(Results const& e) { + CHECK_COUNTER_VALUE(e, int, "Foo", EQ, 15); + CHECK_COUNTER_VALUE(e, int, "Bar", EQ, 25); + CHECK_COUNTER_VALUE(e, int, "Baz", EQ, 45); +} +CHECK_BENCHMARK_RESULTS("BM_CounterSet1_Tabular/threads:%int", &CheckSet1); + +// ========================================================================= // +// ------------------------- Tabular Counters Output ----------------------- // +// ========================================================================= // + +// set only some of the counters, different set now. 
+void BM_CounterSet2_Tabular(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters.insert({ + {"Foo", {10, bm::Counter::kAvgThreads}}, + {"Bat", {30, bm::Counter::kAvgThreads}}, + {"Baz", {40, bm::Counter::kAvgThreads}}, + }); +} +BENCHMARK(BM_CounterSet2_Tabular)->ThreadRange(1, 16); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_CounterSet2_Tabular/threads:%int\",$"}, + {"\"family_index\": 4,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_CounterSet2_Tabular/threads:%int\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"Bat\": %float,$", MR_Next}, + {"\"Baz\": %float,$", MR_Next}, + {"\"Foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_CounterSet2_Tabular/threads:%int\",%csv_report," + ",%float,%float,%float,,"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckSet2(Results const& e) { + CHECK_COUNTER_VALUE(e, int, "Foo", EQ, 10); + CHECK_COUNTER_VALUE(e, int, "Bat", EQ, 30); + CHECK_COUNTER_VALUE(e, int, "Baz", EQ, 40); +} +CHECK_BENCHMARK_RESULTS("BM_CounterSet2_Tabular", &CheckSet2); + +// ========================================================================= // +// --------------------------- TEST CASES END ------------------------------ // +// ========================================================================= // + +int main(int argc, char* argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/test/user_counters_test.cc b/third_party/benchmark/test/user_counters_test.cc new file mode 100644 index 0000000..22252ac --- /dev/null +++ 
b/third_party/benchmark/test/user_counters_test.cc @@ -0,0 +1,561 @@ + +#undef NDEBUG + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// ========================================================================= // +// ---------------------- Testing Prologue Output -------------------------- // +// ========================================================================= // + +// clang-format off + +ADD_CASES(TC_ConsoleOut, + {{"^[-]+$", MR_Next}, + {"^Benchmark %s Time %s CPU %s Iterations UserCounters...$", MR_Next}, + {"^[-]+$", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"%csv_header,\"bar\",\"foo\""}}); + +// clang-format on + +// ========================================================================= // +// ------------------------- Simple Counters Output ------------------------ // +// ========================================================================= // + +void BM_Counters_Simple(benchmark::State& state) { + for (auto _ : state) { + } + state.counters["foo"] = 1; + state.counters["bar"] = 2 * static_cast(state.iterations()); +} +BENCHMARK(BM_Counters_Simple); +ADD_CASES(TC_ConsoleOut, + {{"^BM_Counters_Simple %console_report bar=%hrfloat foo=%hrfloat$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Counters_Simple\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Simple\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Counters_Simple\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() 
+void CheckSimple(Results const& e) { + double its = e.NumIterations(); + CHECK_COUNTER_VALUE(e, int, "foo", EQ, 1); + // check that the value of bar is within 0.1% of the expected value + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 2. * its, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_Simple", &CheckSimple); + +// ========================================================================= // +// --------------------- Counters+Items+Bytes/s Output --------------------- // +// ========================================================================= // + +namespace { +int num_calls1 = 0; +} +void BM_Counters_WithBytesAndItemsPSec(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + state.counters["foo"] = 1; + state.counters["bar"] = ++num_calls1; + state.SetBytesProcessed(364); + state.SetItemsProcessed(150); +} +BENCHMARK(BM_Counters_WithBytesAndItemsPSec); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_WithBytesAndItemsPSec %console_report " + "bar=%hrfloat bytes_per_second=%hrfloat/s " + "foo=%hrfloat items_per_second=%hrfloat/s$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_WithBytesAndItemsPSec\",$"}, + {"\"family_index\": 1,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_WithBytesAndItemsPSec\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"bytes_per_second\": %float,$", MR_Next}, + {"\"foo\": %float,$", MR_Next}, + {"\"items_per_second\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, 
{{"^\"BM_Counters_WithBytesAndItemsPSec\"," + "%csv_bytes_items_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckBytesAndItemsPSec(Results const& e) { + double t = e.DurationCPUTime(); // this (and not real time) is the time used + CHECK_COUNTER_VALUE(e, int, "foo", EQ, 1); + CHECK_COUNTER_VALUE(e, int, "bar", EQ, num_calls1); + // check that the values are within 0.1% of the expected values + CHECK_FLOAT_RESULT_VALUE(e, "bytes_per_second", EQ, 364. / t, 0.001); + CHECK_FLOAT_RESULT_VALUE(e, "items_per_second", EQ, 150. / t, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_WithBytesAndItemsPSec", + &CheckBytesAndItemsPSec); + +// ========================================================================= // +// ------------------------- Rate Counters Output -------------------------- // +// ========================================================================= // + +void BM_Counters_Rate(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{1, bm::Counter::kIsRate}; + state.counters["bar"] = bm::Counter{2, bm::Counter::kIsRate}; +} +BENCHMARK(BM_Counters_Rate); +ADD_CASES( + TC_ConsoleOut, + {{"^BM_Counters_Rate %console_report bar=%hrfloat/s foo=%hrfloat/s$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Counters_Rate\",$"}, + {"\"family_index\": 2,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Rate\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": 
%float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Counters_Rate\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckRate(Results const& e) { + double t = e.DurationCPUTime(); // this (and not real time) is the time used + // check that the values are within 0.1% of the expected values + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, 1. / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 2. / t, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_Rate", &CheckRate); + +// ========================================================================= // +// ----------------------- Inverted Counters Output ------------------------ // +// ========================================================================= // + +void BM_Invert(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{0.0001, bm::Counter::kInvert}; + state.counters["bar"] = bm::Counter{10000, bm::Counter::kInvert}; +} +BENCHMARK(BM_Invert); +ADD_CASES(TC_ConsoleOut, + {{"^BM_Invert %console_report bar=%hrfloatu foo=%hrfloatk$"}}); +ADD_CASES(TC_JSONOut, {{"\"name\": \"BM_Invert\",$"}, + {"\"family_index\": 3,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Invert\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": 
%float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Invert\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckInvert(Results const& e) { + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, 10000, 0.0001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 0.0001, 0.0001); +} +CHECK_BENCHMARK_RESULTS("BM_Invert", &CheckInvert); + +// ========================================================================= // +// --------------------- InvertedRate Counters Output ---------------------- // +// ========================================================================= // + +void BM_Counters_InvertedRate(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters["foo"] = + bm::Counter{1, bm::Counter::kIsRate | bm::Counter::kInvert}; + state.counters["bar"] = + bm::Counter{8192, bm::Counter::kIsRate | bm::Counter::kInvert}; +} +BENCHMARK(BM_Counters_InvertedRate); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_InvertedRate %console_report " + "bar=%hrfloats foo=%hrfloats$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_InvertedRate\",$"}, + {"\"family_index\": 4,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_InvertedRate\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, + 
{{"^\"BM_Counters_InvertedRate\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckInvertedRate(Results const& e) { + double t = e.DurationCPUTime(); // this (and not real time) is the time used + // check that the values are within 0.1% of the expected values + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, t / 8192.0, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_InvertedRate", &CheckInvertedRate); + +// ========================================================================= // +// ------------------------- Thread Counters Output ------------------------ // +// ========================================================================= // + +void BM_Counters_Threads(benchmark::State& state) { + for (auto _ : state) { + } + state.counters["foo"] = 1; + state.counters["bar"] = 2; +} +BENCHMARK(BM_Counters_Threads)->ThreadRange(1, 8); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_Threads/threads:%int %console_report " + "bar=%hrfloat foo=%hrfloat$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Threads/threads:%int\",$"}, + {"\"family_index\": 5,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Threads/threads:%int\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES( + TC_CSVOut, + {{"^\"BM_Counters_Threads/threads:%int\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckThreads(Results const& e) { + 
CHECK_COUNTER_VALUE(e, int, "foo", EQ, e.NumThreads()); + CHECK_COUNTER_VALUE(e, int, "bar", EQ, 2 * e.NumThreads()); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_Threads/threads:%int", &CheckThreads); + +// ========================================================================= // +// ---------------------- ThreadAvg Counters Output ------------------------ // +// ========================================================================= // + +void BM_Counters_AvgThreads(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{1, bm::Counter::kAvgThreads}; + state.counters["bar"] = bm::Counter{2, bm::Counter::kAvgThreads}; +} +BENCHMARK(BM_Counters_AvgThreads)->ThreadRange(1, 8); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_AvgThreads/threads:%int " + "%console_report bar=%hrfloat foo=%hrfloat$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_AvgThreads/threads:%int\",$"}, + {"\"family_index\": 6,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_AvgThreads/threads:%int\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES( + TC_CSVOut, + {{"^\"BM_Counters_AvgThreads/threads:%int\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckAvgThreads(Results const& e) { + CHECK_COUNTER_VALUE(e, int, "foo", EQ, 1); + CHECK_COUNTER_VALUE(e, int, "bar", EQ, 2); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_AvgThreads/threads:%int", + &CheckAvgThreads); + +// 
========================================================================= // +// ---------------------- ThreadAvg Counters Output ------------------------ // +// ========================================================================= // + +void BM_Counters_AvgThreadsRate(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{1, bm::Counter::kAvgThreadsRate}; + state.counters["bar"] = bm::Counter{2, bm::Counter::kAvgThreadsRate}; +} +BENCHMARK(BM_Counters_AvgThreadsRate)->ThreadRange(1, 8); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_AvgThreadsRate/threads:%int " + "%console_report bar=%hrfloat/s foo=%hrfloat/s$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_AvgThreadsRate/threads:%int\",$"}, + {"\"family_index\": 7,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_AvgThreadsRate/threads:%int\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Counters_AvgThreadsRate/" + "threads:%int\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckAvgThreadsRate(Results const& e) { + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, 1. / e.DurationCPUTime(), 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 2. 
/ e.DurationCPUTime(), 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_AvgThreadsRate/threads:%int", + &CheckAvgThreadsRate); + +// ========================================================================= // +// ------------------- IterationInvariant Counters Output ------------------ // +// ========================================================================= // + +void BM_Counters_IterationInvariant(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{1, bm::Counter::kIsIterationInvariant}; + state.counters["bar"] = bm::Counter{2, bm::Counter::kIsIterationInvariant}; +} +BENCHMARK(BM_Counters_IterationInvariant); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_IterationInvariant %console_report " + "bar=%hrfloat foo=%hrfloat$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_IterationInvariant\",$"}, + {"\"family_index\": 8,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_IterationInvariant\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_IterationInvariant\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckIterationInvariant(Results const& e) { + double its = e.NumIterations(); + // check that the values are within 0.1% of the expected value + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, its, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 2. 
* its, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_IterationInvariant", + &CheckIterationInvariant); + +// ========================================================================= // +// ----------------- IterationInvariantRate Counters Output ---------------- // +// ========================================================================= // + +void BM_Counters_kIsIterationInvariantRate(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters["foo"] = + bm::Counter{1, bm::Counter::kIsIterationInvariantRate}; + state.counters["bar"] = + bm::Counter{2, bm::Counter::kIsRate | bm::Counter::kIsIterationInvariant}; +} +BENCHMARK(BM_Counters_kIsIterationInvariantRate); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_kIsIterationInvariantRate " + "%console_report bar=%hrfloat/s foo=%hrfloat/s$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_kIsIterationInvariantRate\",$"}, + {"\"family_index\": 9,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_kIsIterationInvariantRate\",$", + MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Counters_kIsIterationInvariantRate\",%csv_report," + "%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckIsIterationInvariantRate(Results const& e) { + double its = e.NumIterations(); + double t = 
e.DurationCPUTime(); // this (and not real time) is the time used + // check that the values are within 0.1% of the expected values + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, its * 1. / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, its * 2. / t, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_kIsIterationInvariantRate", + &CheckIsIterationInvariantRate); + +// ========================================================================= // +// --------------------- AvgIterations Counters Output --------------------- // +// ========================================================================= // + +void BM_Counters_AvgIterations(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{1, bm::Counter::kAvgIterations}; + state.counters["bar"] = bm::Counter{2, bm::Counter::kAvgIterations}; +} +BENCHMARK(BM_Counters_AvgIterations); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_AvgIterations %console_report " + "bar=%hrfloat foo=%hrfloat$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_AvgIterations\",$"}, + {"\"family_index\": 10,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_AvgIterations\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, + {{"^\"BM_Counters_AvgIterations\",%csv_report,%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckAvgIterations(Results const& e) { + double its = e.NumIterations(); + // check that the values are within 0.1% of the expected value + 
CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, 1. / its, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 2. / its, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_AvgIterations", &CheckAvgIterations); + +// ========================================================================= // +// ------------------- AvgIterationsRate Counters Output ------------------- // +// ========================================================================= // + +void BM_Counters_kAvgIterationsRate(benchmark::State& state) { + for (auto _ : state) { + // This test requires a non-zero CPU time to avoid divide-by-zero + auto iterations = double(state.iterations()) * double(state.iterations()); + benchmark::DoNotOptimize(iterations); + } + namespace bm = benchmark; + state.counters["foo"] = bm::Counter{1, bm::Counter::kAvgIterationsRate}; + state.counters["bar"] = + bm::Counter{2, bm::Counter::kIsRate | bm::Counter::kAvgIterations}; +} +BENCHMARK(BM_Counters_kAvgIterationsRate); +ADD_CASES(TC_ConsoleOut, {{"^BM_Counters_kAvgIterationsRate " + "%console_report bar=%hrfloat/s foo=%hrfloat/s$"}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_kAvgIterationsRate\",$"}, + {"\"family_index\": 11,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_kAvgIterationsRate\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 1,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"bar\": %float,$", MR_Next}, + {"\"foo\": %float$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_CSVOut, {{"^\"BM_Counters_kAvgIterationsRate\",%csv_report," + "%float,%float$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckAvgIterationsRate(Results const& e) { + double its = 
e.NumIterations(); + double t = e.DurationCPUTime(); // this (and not real time) is the time used + // check that the values are within 0.1% of the expected values + CHECK_FLOAT_COUNTER_VALUE(e, "foo", EQ, 1. / its / t, 0.001); + CHECK_FLOAT_COUNTER_VALUE(e, "bar", EQ, 2. / its / t, 0.001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_kAvgIterationsRate", + &CheckAvgIterationsRate); + +// ========================================================================= // +// --------------------------- TEST CASES END ------------------------------ // +// ========================================================================= // + +int main(int argc, char* argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/test/user_counters_thousands_test.cc b/third_party/benchmark/test/user_counters_thousands_test.cc new file mode 100644 index 0000000..fc15383 --- /dev/null +++ b/third_party/benchmark/test/user_counters_thousands_test.cc @@ -0,0 +1,186 @@ + +#undef NDEBUG + +#include "benchmark/benchmark.h" +#include "output_test.h" + +// ========================================================================= // +// ------------------------ Thousands Customisation ------------------------ // +// ========================================================================= // + +void BM_Counters_Thousands(benchmark::State& state) { + for (auto _ : state) { + } + namespace bm = benchmark; + state.counters.insert({ + {"t0_1000000DefaultBase", + bm::Counter(1000 * 1000, bm::Counter::kDefaults)}, + {"t1_1000000Base1000", bm::Counter(1000 * 1000, bm::Counter::kDefaults, + bm::Counter::OneK::kIs1000)}, + {"t2_1000000Base1024", bm::Counter(1000 * 1000, bm::Counter::kDefaults, + bm::Counter::OneK::kIs1024)}, + {"t3_1048576Base1000", bm::Counter(1024 * 1024, bm::Counter::kDefaults, + bm::Counter::OneK::kIs1000)}, + {"t4_1048576Base1024", bm::Counter(1024 * 1024, bm::Counter::kDefaults, + bm::Counter::OneK::kIs1024)}, + }); +} +BENCHMARK(BM_Counters_Thousands)->Repetitions(2); 
+ADD_CASES( + TC_ConsoleOut, + { + {"^BM_Counters_Thousands/repeats:2 %console_report " + "t0_1000000DefaultBase=1M " + "t1_1000000Base1000=1M t2_1000000Base1024=976.56[23]Ki " + "t3_1048576Base1000=1.04858M t4_1048576Base1024=1Mi$"}, + {"^BM_Counters_Thousands/repeats:2 %console_report " + "t0_1000000DefaultBase=1M " + "t1_1000000Base1000=1M t2_1000000Base1024=976.56[23]Ki " + "t3_1048576Base1000=1.04858M t4_1048576Base1024=1Mi$"}, + {"^BM_Counters_Thousands/repeats:2_mean %console_report " + "t0_1000000DefaultBase=1M t1_1000000Base1000=1M " + "t2_1000000Base1024=976.56[23]Ki t3_1048576Base1000=1.04858M " + "t4_1048576Base1024=1Mi$"}, + {"^BM_Counters_Thousands/repeats:2_median %console_report " + "t0_1000000DefaultBase=1M t1_1000000Base1000=1M " + "t2_1000000Base1024=976.56[23]Ki t3_1048576Base1000=1.04858M " + "t4_1048576Base1024=1Mi$"}, + {"^BM_Counters_Thousands/repeats:2_stddev %console_time_only_report [ " + "]*2 t0_1000000DefaultBase=0 t1_1000000Base1000=0 " + "t2_1000000Base1024=0 t3_1048576Base1000=0 t4_1048576Base1024=0$"}, + }); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Thousands/repeats:2\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Thousands/repeats:2\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 0,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"t0_1000000DefaultBase\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t1_1000000Base1000\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t2_1000000Base1024\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t3_1048576Base1000\": 1\\.048576(0)*e\\+(0)*6,$", MR_Next}, + {"\"t4_1048576Base1024\": 1\\.048576(0)*e\\+(0)*6$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": 
\"BM_Counters_Thousands/repeats:2\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Thousands/repeats:2\",$", MR_Next}, + {"\"run_type\": \"iteration\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"repetition_index\": 1,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"iterations\": %int,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"t0_1000000DefaultBase\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t1_1000000Base1000\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t2_1000000Base1024\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t3_1048576Base1000\": 1\\.048576(0)*e\\+(0)*6,$", MR_Next}, + {"\"t4_1048576Base1024\": 1\\.048576(0)*e\\+(0)*6$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Thousands/repeats:2_mean\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Thousands/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"mean\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"t0_1000000DefaultBase\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t1_1000000Base1000\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t2_1000000Base1024\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t3_1048576Base1000\": 1\\.048576(0)*e\\+(0)*6,$", MR_Next}, + {"\"t4_1048576Base1024\": 1\\.048576(0)*e\\+(0)*6$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Thousands/repeats:2_median\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": 
\"BM_Counters_Thousands/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"median\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"t0_1000000DefaultBase\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t1_1000000Base1000\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t2_1000000Base1024\": 1\\.(0)*e\\+(0)*6,$", MR_Next}, + {"\"t3_1048576Base1000\": 1\\.048576(0)*e\\+(0)*6,$", MR_Next}, + {"\"t4_1048576Base1024\": 1\\.048576(0)*e\\+(0)*6$", MR_Next}, + {"}", MR_Next}}); +ADD_CASES(TC_JSONOut, + {{"\"name\": \"BM_Counters_Thousands/repeats:2_stddev\",$"}, + {"\"family_index\": 0,$", MR_Next}, + {"\"per_family_instance_index\": 0,$", MR_Next}, + {"\"run_name\": \"BM_Counters_Thousands/repeats:2\",$", MR_Next}, + {"\"run_type\": \"aggregate\",$", MR_Next}, + {"\"repetitions\": 2,$", MR_Next}, + {"\"threads\": 1,$", MR_Next}, + {"\"aggregate_name\": \"stddev\",$", MR_Next}, + {"\"aggregate_unit\": \"time\",$", MR_Next}, + {"\"iterations\": 2,$", MR_Next}, + {"\"real_time\": %float,$", MR_Next}, + {"\"cpu_time\": %float,$", MR_Next}, + {"\"time_unit\": \"ns\",$", MR_Next}, + {"\"t0_1000000DefaultBase\": 0\\.(0)*e\\+(0)*,$", MR_Next}, + {"\"t1_1000000Base1000\": 0\\.(0)*e\\+(0)*,$", MR_Next}, + {"\"t2_1000000Base1024\": 0\\.(0)*e\\+(0)*,$", MR_Next}, + {"\"t3_1048576Base1000\": 0\\.(0)*e\\+(0)*,$", MR_Next}, + {"\"t4_1048576Base1024\": 0\\.(0)*e\\+(0)*$", MR_Next}, + {"}", MR_Next}}); + +ADD_CASES( + TC_CSVOut, + {{"^\"BM_Counters_Thousands/" + "repeats:2\",%csv_report,1e\\+(0)*6,1e\\+(0)*6,1e\\+(0)*6,1\\.04858e\\+(" + "0)*6,1\\.04858e\\+(0)*6$"}, + {"^\"BM_Counters_Thousands/" + "repeats:2\",%csv_report,1e\\+(0)*6,1e\\+(0)*6,1e\\+(0)*6,1\\.04858e\\+(" + "0)*6,1\\.04858e\\+(0)*6$"}, + 
{"^\"BM_Counters_Thousands/" + "repeats:2_mean\",%csv_report,1e\\+(0)*6,1e\\+(0)*6,1e\\+(0)*6,1\\." + "04858e\\+(0)*6,1\\.04858e\\+(0)*6$"}, + {"^\"BM_Counters_Thousands/" + "repeats:2_median\",%csv_report,1e\\+(0)*6,1e\\+(0)*6,1e\\+(0)*6,1\\." + "04858e\\+(0)*6,1\\.04858e\\+(0)*6$"}, + {"^\"BM_Counters_Thousands/repeats:2_stddev\",%csv_report,0,0,0,0,0$"}}); +// VS2013 does not allow this function to be passed as a lambda argument +// to CHECK_BENCHMARK_RESULTS() +void CheckThousands(Results const& e) { + if (e.name != "BM_Counters_Thousands/repeats:2") + return; // Do not check the aggregates! + + // check that the values are within 0.01% of the expected values + CHECK_FLOAT_COUNTER_VALUE(e, "t0_1000000DefaultBase", EQ, 1000 * 1000, + 0.0001); + CHECK_FLOAT_COUNTER_VALUE(e, "t1_1000000Base1000", EQ, 1000 * 1000, 0.0001); + CHECK_FLOAT_COUNTER_VALUE(e, "t2_1000000Base1024", EQ, 1000 * 1000, 0.0001); + CHECK_FLOAT_COUNTER_VALUE(e, "t3_1048576Base1000", EQ, 1024 * 1024, 0.0001); + CHECK_FLOAT_COUNTER_VALUE(e, "t4_1048576Base1024", EQ, 1024 * 1024, 0.0001); +} +CHECK_BENCHMARK_RESULTS("BM_Counters_Thousands", &CheckThousands); + +// ========================================================================= // +// --------------------------- TEST CASES END ------------------------------ // +// ========================================================================= // + +int main(int argc, char* argv[]) { RunOutputTests(argc, argv); } diff --git a/third_party/benchmark/tools/BUILD.bazel b/third_party/benchmark/tools/BUILD.bazel new file mode 100644 index 0000000..8ef6a86 --- /dev/null +++ b/third_party/benchmark/tools/BUILD.bazel @@ -0,0 +1,20 @@ +load("@tools_pip_deps//:requirements.bzl", "requirement") + +py_library( + name = "gbench", + srcs = glob(["gbench/*.py"]), + deps = [ + requirement("numpy"), + requirement("scipy"), + ], +) + +py_binary( + name = "compare", + srcs = ["compare.py"], + imports = ["."], + python_version = "PY3", + deps = [ + ":gbench", + ], +) 
diff --git a/third_party/benchmark/tools/compare.py b/third_party/benchmark/tools/compare.py new file mode 100755 index 0000000..7572520 --- /dev/null +++ b/third_party/benchmark/tools/compare.py @@ -0,0 +1,523 @@ +#!/usr/bin/env python3 + +# type: ignore + +""" +compare.py - versatile benchmark output compare tool +""" + +import argparse +import json +import os +import sys +import unittest +from argparse import ArgumentParser + +import gbench +from gbench import report, util + + +def check_inputs(in1, in2, flags): + """ + Perform checking on the user provided inputs and diagnose any abnormalities + """ + in1_kind, in1_err = util.classify_input_file(in1) + in2_kind, in2_err = util.classify_input_file(in2) + output_file = util.find_benchmark_flag("--benchmark_out=", flags) + output_type = util.find_benchmark_flag("--benchmark_out_format=", flags) + if ( + in1_kind == util.IT_Executable + and in2_kind == util.IT_Executable + and output_file + ): + print( + ( + "WARNING: '--benchmark_out=%s' will be passed to both " + "benchmarks causing it to be overwritten" + ) + % output_file + ) + if in1_kind == util.IT_JSON and in2_kind == util.IT_JSON: + # When both sides are JSON the only supported flag is + # --benchmark_filter= + for flag in util.remove_benchmark_flags("--benchmark_filter=", flags): + print( + "WARNING: passing %s has no effect since both " + "inputs are JSON" % flag + ) + if output_type is not None and output_type != "json": + print( + ( + "ERROR: passing '--benchmark_out_format=%s' to 'compare.py`" + " is not supported." + ) + % output_type + ) + sys.exit(1) + + +def create_parser(): + parser = ArgumentParser( + description="versatile benchmark output compare tool" + ) + + parser.add_argument( + "-a", + "--display_aggregates_only", + dest="display_aggregates_only", + action="store_true", + help="If there are repetitions, by default, we display everything - the" + " actual runs, and the aggregates computed. 
Sometimes, it is " + "desirable to only view the aggregates. E.g. when there are a lot " + "of repetitions. Do note that only the display is affected. " + "Internally, all the actual runs are still used, e.g. for U test.", + ) + + parser.add_argument( + "--no-color", + dest="color", + default=True, + action="store_false", + help="Do not use colors in the terminal output", + ) + + parser.add_argument( + "-d", + "--dump_to_json", + dest="dump_to_json", + help="Additionally, dump benchmark comparison output to this file in JSON format.", + ) + + utest = parser.add_argument_group() + utest.add_argument( + "--no-utest", + dest="utest", + default=True, + action="store_false", + help="The tool can do a two-tailed Mann-Whitney U test with the null hypothesis that it is equally likely that a randomly selected value from one sample will be less than or greater than a randomly selected value from a second sample.\nWARNING: requires **LARGE** (no less than {}) number of repetitions to be meaningful!\nThe test is being done by default, if at least {} repetitions were done.\nThis option can disable the U Test.".format( + report.UTEST_OPTIMAL_REPETITIONS, report.UTEST_MIN_REPETITIONS + ), + ) + alpha_default = 0.05 + utest.add_argument( + "--alpha", + dest="utest_alpha", + default=alpha_default, + type=float, + help=( + "significance level alpha. 
if the calculated p-value is below this value, then the result is said to be statistically significant and the null hypothesis is rejected.\n(default: %0.4f)" + ) + % alpha_default, + ) + + subparsers = parser.add_subparsers( + help="This tool has multiple modes of operation:", dest="mode" + ) + + parser_a = subparsers.add_parser( + "benchmarks", + help="The most simple use-case, compare all the output of these two benchmarks", + ) + baseline = parser_a.add_argument_group("baseline", "The benchmark baseline") + baseline.add_argument( + "test_baseline", + metavar="test_baseline", + type=argparse.FileType("r"), + nargs=1, + help="A benchmark executable or JSON output file", + ) + contender = parser_a.add_argument_group( + "contender", "The benchmark that will be compared against the baseline" + ) + contender.add_argument( + "test_contender", + metavar="test_contender", + type=argparse.FileType("r"), + nargs=1, + help="A benchmark executable or JSON output file", + ) + parser_a.add_argument( + "benchmark_options", + metavar="benchmark_options", + nargs=argparse.REMAINDER, + help="Arguments to pass when running benchmark executables", + ) + + parser_b = subparsers.add_parser( + "filters", help="Compare filter one with the filter two of benchmark" + ) + baseline = parser_b.add_argument_group("baseline", "The benchmark baseline") + baseline.add_argument( + "test", + metavar="test", + type=argparse.FileType("r"), + nargs=1, + help="A benchmark executable or JSON output file", + ) + baseline.add_argument( + "filter_baseline", + metavar="filter_baseline", + type=str, + nargs=1, + help="The first filter, that will be used as baseline", + ) + contender = parser_b.add_argument_group( + "contender", "The benchmark that will be compared against the baseline" + ) + contender.add_argument( + "filter_contender", + metavar="filter_contender", + type=str, + nargs=1, + help="The second filter, that will be compared against the baseline", + ) + parser_b.add_argument( + 
"benchmark_options", + metavar="benchmark_options", + nargs=argparse.REMAINDER, + help="Arguments to pass when running benchmark executables", + ) + + parser_c = subparsers.add_parser( + "benchmarksfiltered", + help="Compare filter one of first benchmark with filter two of the second benchmark", + ) + baseline = parser_c.add_argument_group("baseline", "The benchmark baseline") + baseline.add_argument( + "test_baseline", + metavar="test_baseline", + type=argparse.FileType("r"), + nargs=1, + help="A benchmark executable or JSON output file", + ) + baseline.add_argument( + "filter_baseline", + metavar="filter_baseline", + type=str, + nargs=1, + help="The first filter, that will be used as baseline", + ) + contender = parser_c.add_argument_group( + "contender", "The benchmark that will be compared against the baseline" + ) + contender.add_argument( + "test_contender", + metavar="test_contender", + type=argparse.FileType("r"), + nargs=1, + help="The second benchmark executable or JSON output file, that will be compared against the baseline", + ) + contender.add_argument( + "filter_contender", + metavar="filter_contender", + type=str, + nargs=1, + help="The second filter, that will be compared against the baseline", + ) + parser_c.add_argument( + "benchmark_options", + metavar="benchmark_options", + nargs=argparse.REMAINDER, + help="Arguments to pass when running benchmark executables", + ) + + return parser + + +def main(): + # Parse the command line flags + parser = create_parser() + args, unknown_args = parser.parse_known_args() + if args.mode is None: + parser.print_help() + exit(1) + assert not unknown_args + benchmark_options = args.benchmark_options + + if args.mode == "benchmarks": + test_baseline = args.test_baseline[0].name + test_contender = args.test_contender[0].name + filter_baseline = "" + filter_contender = "" + + # NOTE: if test_baseline == test_contender, you are analyzing the stdev + + description = "Comparing %s to %s" % (test_baseline, 
test_contender) + elif args.mode == "filters": + test_baseline = args.test[0].name + test_contender = args.test[0].name + filter_baseline = args.filter_baseline[0] + filter_contender = args.filter_contender[0] + + # NOTE: if filter_baseline == filter_contender, you are analyzing the + # stdev + + description = "Comparing %s to %s (from %s)" % ( + filter_baseline, + filter_contender, + args.test[0].name, + ) + elif args.mode == "benchmarksfiltered": + test_baseline = args.test_baseline[0].name + test_contender = args.test_contender[0].name + filter_baseline = args.filter_baseline[0] + filter_contender = args.filter_contender[0] + + # NOTE: if test_baseline == test_contender and + # filter_baseline == filter_contender, you are analyzing the stdev + + description = "Comparing %s (from %s) to %s (from %s)" % ( + filter_baseline, + test_baseline, + filter_contender, + test_contender, + ) + else: + # should never happen + print("Unrecognized mode of operation: '%s'" % args.mode) + parser.print_help() + exit(1) + + check_inputs(test_baseline, test_contender, benchmark_options) + + if args.display_aggregates_only: + benchmark_options += ["--benchmark_display_aggregates_only=true"] + + options_baseline = [] + options_contender = [] + + if filter_baseline and filter_contender: + options_baseline = ["--benchmark_filter=%s" % filter_baseline] + options_contender = ["--benchmark_filter=%s" % filter_contender] + + # Run the benchmarks and report the results + json1 = json1_orig = gbench.util.sort_benchmark_results( + gbench.util.run_or_load_benchmark( + test_baseline, benchmark_options + options_baseline + ) + ) + json2 = json2_orig = gbench.util.sort_benchmark_results( + gbench.util.run_or_load_benchmark( + test_contender, benchmark_options + options_contender + ) + ) + + # Now, filter the benchmarks so that the difference report can work + if filter_baseline and filter_contender: + replacement = "[%s vs. 
%s]" % (filter_baseline, filter_contender) + json1 = gbench.report.filter_benchmark( + json1_orig, filter_baseline, replacement + ) + json2 = gbench.report.filter_benchmark( + json2_orig, filter_contender, replacement + ) + + diff_report = gbench.report.get_difference_report(json1, json2, args.utest) + output_lines = gbench.report.print_difference_report( + diff_report, + args.display_aggregates_only, + args.utest, + args.utest_alpha, + args.color, + ) + print(description) + for ln in output_lines: + print(ln) + + # Optionally, diff and output to JSON + if args.dump_to_json is not None: + with open(args.dump_to_json, "w") as f_json: + json.dump(diff_report, f_json, indent=1) + + +class TestParser(unittest.TestCase): + def setUp(self): + self.parser = create_parser() + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "gbench", "Inputs" + ) + self.testInput0 = os.path.join(testInputs, "test1_run1.json") + self.testInput1 = os.path.join(testInputs, "test1_run2.json") + + def test_benchmarks_basic(self): + parsed = self.parser.parse_args( + ["benchmarks", self.testInput0, self.testInput1] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertFalse(parsed.benchmark_options) + + def test_benchmarks_basic_without_utest(self): + parsed = self.parser.parse_args( + ["--no-utest", "benchmarks", self.testInput0, self.testInput1] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertFalse(parsed.utest) + self.assertEqual(parsed.utest_alpha, 0.05) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertFalse(parsed.benchmark_options) + + def 
test_benchmarks_basic_display_aggregates_only(self): + parsed = self.parser.parse_args( + ["-a", "benchmarks", self.testInput0, self.testInput1] + ) + self.assertTrue(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertFalse(parsed.benchmark_options) + + def test_benchmarks_basic_with_utest_alpha(self): + parsed = self.parser.parse_args( + ["--alpha=0.314", "benchmarks", self.testInput0, self.testInput1] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.utest_alpha, 0.314) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertFalse(parsed.benchmark_options) + + def test_benchmarks_basic_without_utest_with_utest_alpha(self): + parsed = self.parser.parse_args( + [ + "--no-utest", + "--alpha=0.314", + "benchmarks", + self.testInput0, + self.testInput1, + ] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertFalse(parsed.utest) + self.assertEqual(parsed.utest_alpha, 0.314) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertFalse(parsed.benchmark_options) + + def test_benchmarks_with_remainder(self): + parsed = self.parser.parse_args( + ["benchmarks", self.testInput0, self.testInput1, "d"] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertEqual(parsed.benchmark_options, ["d"]) + + def 
test_benchmarks_with_remainder_after_doubleminus(self): + parsed = self.parser.parse_args( + ["benchmarks", self.testInput0, self.testInput1, "--", "e"] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarks") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertEqual(parsed.benchmark_options, ["e"]) + + def test_filters_basic(self): + parsed = self.parser.parse_args(["filters", self.testInput0, "c", "d"]) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "filters") + self.assertEqual(parsed.test[0].name, self.testInput0) + self.assertEqual(parsed.filter_baseline[0], "c") + self.assertEqual(parsed.filter_contender[0], "d") + self.assertFalse(parsed.benchmark_options) + + def test_filters_with_remainder(self): + parsed = self.parser.parse_args( + ["filters", self.testInput0, "c", "d", "e"] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "filters") + self.assertEqual(parsed.test[0].name, self.testInput0) + self.assertEqual(parsed.filter_baseline[0], "c") + self.assertEqual(parsed.filter_contender[0], "d") + self.assertEqual(parsed.benchmark_options, ["e"]) + + def test_filters_with_remainder_after_doubleminus(self): + parsed = self.parser.parse_args( + ["filters", self.testInput0, "c", "d", "--", "f"] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "filters") + self.assertEqual(parsed.test[0].name, self.testInput0) + self.assertEqual(parsed.filter_baseline[0], "c") + self.assertEqual(parsed.filter_contender[0], "d") + self.assertEqual(parsed.benchmark_options, ["f"]) + + def test_benchmarksfiltered_basic(self): + parsed = self.parser.parse_args( + ["benchmarksfiltered", self.testInput0, "c", 
self.testInput1, "e"] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarksfiltered") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.filter_baseline[0], "c") + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertEqual(parsed.filter_contender[0], "e") + self.assertFalse(parsed.benchmark_options) + + def test_benchmarksfiltered_with_remainder(self): + parsed = self.parser.parse_args( + [ + "benchmarksfiltered", + self.testInput0, + "c", + self.testInput1, + "e", + "f", + ] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarksfiltered") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.filter_baseline[0], "c") + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertEqual(parsed.filter_contender[0], "e") + self.assertEqual(parsed.benchmark_options[0], "f") + + def test_benchmarksfiltered_with_remainder_after_doubleminus(self): + parsed = self.parser.parse_args( + [ + "benchmarksfiltered", + self.testInput0, + "c", + self.testInput1, + "e", + "--", + "g", + ] + ) + self.assertFalse(parsed.display_aggregates_only) + self.assertTrue(parsed.utest) + self.assertEqual(parsed.mode, "benchmarksfiltered") + self.assertEqual(parsed.test_baseline[0].name, self.testInput0) + self.assertEqual(parsed.filter_baseline[0], "c") + self.assertEqual(parsed.test_contender[0].name, self.testInput1) + self.assertEqual(parsed.filter_contender[0], "e") + self.assertEqual(parsed.benchmark_options[0], "g") + + +if __name__ == "__main__": + # unittest.main() + main() + +# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 +# kate: tab-width: 4; replace-tabs on; indent-width 4; tab-indents: off; +# kate: indent-mode python; remove-trailing-spaces modified; diff --git 
a/third_party/benchmark/tools/gbench/Inputs/test1_run1.json b/third_party/benchmark/tools/gbench/Inputs/test1_run1.json new file mode 100644 index 0000000..9daed0b --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test1_run1.json @@ -0,0 +1,127 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_SameTimes", + "iterations": 1000, + "real_time": 10, + "cpu_time": 10, + "time_unit": "ns" + }, + { + "name": "BM_2xFaster", + "iterations": 1000, + "real_time": 50, + "cpu_time": 50, + "time_unit": "ns" + }, + { + "name": "BM_2xSlower", + "iterations": 1000, + "real_time": 50, + "cpu_time": 50, + "time_unit": "ns" + }, + { + "name": "BM_1PercentFaster", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_1PercentSlower", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_10PercentFaster", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_10PercentSlower", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_100xSlower", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_100xFaster", + "iterations": 1000, + "real_time": 10000, + "cpu_time": 10000, + "time_unit": "ns" + }, + { + "name": "BM_10PercentCPUToTime", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_ThirdFaster", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "MyComplexityTest_BigO", + "run_name": "MyComplexityTest", + "run_type": "aggregate", + "aggregate_name": "BigO", + "cpu_coefficient": 4.2749856294592886e+00, + "real_coefficient": 6.4789275289789780e+00, + "big_o": "N", + "time_unit": 
"ns" + }, + { + "name": "MyComplexityTest_RMS", + "run_name": "MyComplexityTest", + "run_type": "aggregate", + "aggregate_name": "RMS", + "rms": 4.5097802512472874e-03 + }, + { + "name": "BM_NotBadTimeUnit", + "iterations": 1000, + "real_time": 0.4, + "cpu_time": 0.5, + "time_unit": "s" + }, + { + "name": "BM_DifferentTimeUnit", + "iterations": 1, + "real_time": 1, + "cpu_time": 1, + "time_unit": "s" + }, + { + "name": "BM_hasLabel", + "label": "a label", + "iterations": 1, + "real_time": 1, + "cpu_time": 1, + "time_unit": "s" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test1_run2.json b/third_party/benchmark/tools/gbench/Inputs/test1_run2.json new file mode 100644 index 0000000..dc52970 --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test1_run2.json @@ -0,0 +1,127 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_SameTimes", + "iterations": 1000, + "real_time": 10, + "cpu_time": 10, + "time_unit": "ns" + }, + { + "name": "BM_2xFaster", + "iterations": 1000, + "real_time": 25, + "cpu_time": 25, + "time_unit": "ns" + }, + { + "name": "BM_2xSlower", + "iterations": 20833333, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_1PercentFaster", + "iterations": 1000, + "real_time": 98.9999999, + "cpu_time": 98.9999999, + "time_unit": "ns" + }, + { + "name": "BM_1PercentSlower", + "iterations": 1000, + "real_time": 100.9999999, + "cpu_time": 100.9999999, + "time_unit": "ns" + }, + { + "name": "BM_10PercentFaster", + "iterations": 1000, + "real_time": 90, + "cpu_time": 90, + "time_unit": "ns" + }, + { + "name": "BM_10PercentSlower", + "iterations": 1000, + "real_time": 110, + "cpu_time": 110, + "time_unit": "ns" + }, + { + "name": "BM_100xSlower", + "iterations": 1000, + "real_time": 1.0000e+04, + "cpu_time": 1.0000e+04, + "time_unit": "ns" + }, + { + "name": 
"BM_100xFaster", + "iterations": 1000, + "real_time": 100, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_10PercentCPUToTime", + "iterations": 1000, + "real_time": 110, + "cpu_time": 90, + "time_unit": "ns" + }, + { + "name": "BM_ThirdFaster", + "iterations": 1000, + "real_time": 66.665, + "cpu_time": 66.664, + "time_unit": "ns" + }, + { + "name": "MyComplexityTest_BigO", + "run_name": "MyComplexityTest", + "run_type": "aggregate", + "aggregate_name": "BigO", + "cpu_coefficient": 5.6215779594361486e+00, + "real_coefficient": 5.6288314793554610e+00, + "big_o": "N", + "time_unit": "ns" + }, + { + "name": "MyComplexityTest_RMS", + "run_name": "MyComplexityTest", + "run_type": "aggregate", + "aggregate_name": "RMS", + "rms": 3.3128901852342174e-03 + }, + { + "name": "BM_NotBadTimeUnit", + "iterations": 1000, + "real_time": 0.04, + "cpu_time": 0.6, + "time_unit": "s" + }, + { + "name": "BM_DifferentTimeUnit", + "iterations": 1, + "real_time": 1, + "cpu_time": 1, + "time_unit": "ns" + }, + { + "name": "BM_hasLabel", + "label": "a label", + "iterations": 1, + "real_time": 1, + "cpu_time": 1, + "time_unit": "s" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test2_run.json b/third_party/benchmark/tools/gbench/Inputs/test2_run.json new file mode 100644 index 0000000..15bc698 --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test2_run.json @@ -0,0 +1,81 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_Hi", + "iterations": 1234, + "real_time": 42, + "cpu_time": 24, + "time_unit": "ms" + }, + { + "name": "BM_Zero", + "iterations": 1000, + "real_time": 10, + "cpu_time": 10, + "time_unit": "ns" + }, + { + "name": "BM_Zero/4", + "iterations": 4000, + "real_time": 40, + "cpu_time": 40, + "time_unit": "ns" + }, + { + "name": "Prefix/BM_Zero", + "iterations": 2000, + "real_time": 20, + 
"cpu_time": 20, + "time_unit": "ns" + }, + { + "name": "Prefix/BM_Zero/3", + "iterations": 3000, + "real_time": 30, + "cpu_time": 30, + "time_unit": "ns" + }, + { + "name": "BM_One", + "iterations": 5000, + "real_time": 5, + "cpu_time": 5, + "time_unit": "ns" + }, + { + "name": "BM_One/4", + "iterations": 2000, + "real_time": 20, + "cpu_time": 20, + "time_unit": "ns" + }, + { + "name": "Prefix/BM_One", + "iterations": 1000, + "real_time": 10, + "cpu_time": 10, + "time_unit": "ns" + }, + { + "name": "Prefix/BM_One/3", + "iterations": 1500, + "real_time": 15, + "cpu_time": 15, + "time_unit": "ns" + }, + { + "name": "BM_Bye", + "iterations": 5321, + "real_time": 11, + "cpu_time": 63, + "time_unit": "ns" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test3_run0.json b/third_party/benchmark/tools/gbench/Inputs/test3_run0.json new file mode 100644 index 0000000..49f8b06 --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test3_run0.json @@ -0,0 +1,65 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_One", + "run_type": "aggregate", + "iterations": 1000, + "real_time": 10, + "cpu_time": 100, + "time_unit": "ns" + }, + { + "name": "BM_Two", + "iterations": 1000, + "real_time": 9, + "cpu_time": 90, + "time_unit": "ns" + }, + { + "name": "BM_Two", + "iterations": 1000, + "real_time": 8, + "cpu_time": 86, + "time_unit": "ns" + }, + { + "name": "short", + "run_type": "aggregate", + "iterations": 1000, + "real_time": 8, + "cpu_time": 80, + "time_unit": "ns" + }, + { + "name": "short", + "run_type": "aggregate", + "iterations": 1000, + "real_time": 8, + "cpu_time": 77, + "time_unit": "ns" + }, + { + "name": "medium", + "run_type": "iteration", + "iterations": 1000, + "real_time": 8, + "cpu_time": 80, + "time_unit": "ns" + }, + { + "name": "medium", + "run_type": "iteration", + "iterations": 1000, + 
"real_time": 9, + "cpu_time": 82, + "time_unit": "ns" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test3_run1.json b/third_party/benchmark/tools/gbench/Inputs/test3_run1.json new file mode 100644 index 0000000..acc5ba1 --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test3_run1.json @@ -0,0 +1,65 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_One", + "iterations": 1000, + "real_time": 9, + "cpu_time": 110, + "time_unit": "ns" + }, + { + "name": "BM_Two", + "run_type": "aggregate", + "iterations": 1000, + "real_time": 10, + "cpu_time": 89, + "time_unit": "ns" + }, + { + "name": "BM_Two", + "iterations": 1000, + "real_time": 7, + "cpu_time": 72, + "time_unit": "ns" + }, + { + "name": "short", + "run_type": "aggregate", + "iterations": 1000, + "real_time": 7, + "cpu_time": 75, + "time_unit": "ns" + }, + { + "name": "short", + "run_type": "aggregate", + "iterations": 762, + "real_time": 4.54, + "cpu_time": 66.6, + "time_unit": "ns" + }, + { + "name": "short", + "run_type": "iteration", + "iterations": 1000, + "real_time": 800, + "cpu_time": 1, + "time_unit": "ns" + }, + { + "name": "medium", + "run_type": "iteration", + "iterations": 1200, + "real_time": 5, + "cpu_time": 53, + "time_unit": "ns" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test4_run.json b/third_party/benchmark/tools/gbench/Inputs/test4_run.json new file mode 100644 index 0000000..eaa005f --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test4_run.json @@ -0,0 +1,96 @@ +{ + "benchmarks": [ + { + "name": "99 family 0 instance 0 repetition 0", + "run_type": "iteration", + "family_index": 0, + "per_family_instance_index": 0, + "repetition_index": 0 + }, + { + "name": "98 family 0 instance 0 repetition 1", + "run_type": "iteration", + "family_index": 0, + "per_family_instance_index": 0, + 
"repetition_index": 1 + }, + { + "name": "97 family 0 instance 0 aggregate", + "run_type": "aggregate", + "family_index": 0, + "per_family_instance_index": 0, + "aggregate_name": "9 aggregate" + }, + + + { + "name": "96 family 0 instance 1 repetition 0", + "run_type": "iteration", + "family_index": 0, + "per_family_instance_index": 1, + "repetition_index": 0 + }, + { + "name": "95 family 0 instance 1 repetition 1", + "run_type": "iteration", + "family_index": 0, + "per_family_instance_index": 1, + "repetition_index": 1 + }, + { + "name": "94 family 0 instance 1 aggregate", + "run_type": "aggregate", + "family_index": 0, + "per_family_instance_index": 1, + "aggregate_name": "9 aggregate" + }, + + + + + { + "name": "93 family 1 instance 0 repetition 0", + "run_type": "iteration", + "family_index": 1, + "per_family_instance_index": 0, + "repetition_index": 0 + }, + { + "name": "92 family 1 instance 0 repetition 1", + "run_type": "iteration", + "family_index": 1, + "per_family_instance_index": 0, + "repetition_index": 1 + }, + { + "name": "91 family 1 instance 0 aggregate", + "run_type": "aggregate", + "family_index": 1, + "per_family_instance_index": 0, + "aggregate_name": "9 aggregate" + }, + + + { + "name": "90 family 1 instance 1 repetition 0", + "run_type": "iteration", + "family_index": 1, + "per_family_instance_index": 1, + "repetition_index": 0 + }, + { + "name": "89 family 1 instance 1 repetition 1", + "run_type": "iteration", + "family_index": 1, + "per_family_instance_index": 1, + "repetition_index": 1 + }, + { + "name": "88 family 1 instance 1 aggregate", + "run_type": "aggregate", + "family_index": 1, + "per_family_instance_index": 1, + "aggregate_name": "9 aggregate" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test4_run0.json b/third_party/benchmark/tools/gbench/Inputs/test4_run0.json new file mode 100644 index 0000000..54cf127 --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test4_run0.json @@ -0,0 +1,21 @@ +{ + 
"context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "whocares", + "run_type": "aggregate", + "aggregate_name": "zz", + "aggregate_unit": "percentage", + "iterations": 1000, + "real_time": 0.01, + "cpu_time": 0.10, + "time_unit": "ns" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test4_run1.json b/third_party/benchmark/tools/gbench/Inputs/test4_run1.json new file mode 100644 index 0000000..25d5605 --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test4_run1.json @@ -0,0 +1,21 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "whocares", + "run_type": "aggregate", + "aggregate_name": "zz", + "aggregate_unit": "percentage", + "iterations": 1000, + "real_time": 0.005, + "cpu_time": 0.15, + "time_unit": "ns" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test5_run0.json b/third_party/benchmark/tools/gbench/Inputs/test5_run0.json new file mode 100644 index 0000000..074103b --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test5_run0.json @@ -0,0 +1,18 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + "cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_ManyRepetitions", + "iterations": 1000, + "real_time": 1, + "cpu_time": 1000, + "time_unit": "s" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/Inputs/test5_run1.json b/third_party/benchmark/tools/gbench/Inputs/test5_run1.json new file mode 100644 index 0000000..430df9f --- /dev/null +++ b/third_party/benchmark/tools/gbench/Inputs/test5_run1.json @@ -0,0 +1,18 @@ +{ + "context": { + "date": "2016-08-02 17:44:46", + "num_cpus": 4, + "mhz_per_cpu": 4228, + 
"cpu_scaling_enabled": false, + "library_build_type": "release" + }, + "benchmarks": [ + { + "name": "BM_ManyRepetitions", + "iterations": 1000, + "real_time": 1000, + "cpu_time": 1, + "time_unit": "s" + } + ] +} diff --git a/third_party/benchmark/tools/gbench/__init__.py b/third_party/benchmark/tools/gbench/__init__.py new file mode 100644 index 0000000..9212568 --- /dev/null +++ b/third_party/benchmark/tools/gbench/__init__.py @@ -0,0 +1,8 @@ +"""Google Benchmark tooling""" + +__author__ = "Eric Fiselier" +__email__ = "eric@efcs.ca" +__versioninfo__ = (0, 5, 0) +__version__ = ".".join(str(v) for v in __versioninfo__) + "dev" + +__all__ = [] # type: ignore diff --git a/third_party/benchmark/tools/gbench/report.py b/third_party/benchmark/tools/gbench/report.py new file mode 100644 index 0000000..7158fd1 --- /dev/null +++ b/third_party/benchmark/tools/gbench/report.py @@ -0,0 +1,1619 @@ +# type: ignore + +""" +report.py - Utilities for reporting statistics about benchmark results +""" + +import copy +import os +import random +import re +import unittest + +from numpy import array +from scipy.stats import gmean, mannwhitneyu + + +class BenchmarkColor(object): + def __init__(self, name, code): + self.name = name + self.code = code + + def __repr__(self): + return "%s%r" % (self.__class__.__name__, (self.name, self.code)) + + def __format__(self, format): + return self.code + + +# Benchmark Colors Enumeration +BC_NONE = BenchmarkColor("NONE", "") +BC_MAGENTA = BenchmarkColor("MAGENTA", "\033[95m") +BC_CYAN = BenchmarkColor("CYAN", "\033[96m") +BC_OKBLUE = BenchmarkColor("OKBLUE", "\033[94m") +BC_OKGREEN = BenchmarkColor("OKGREEN", "\033[32m") +BC_HEADER = BenchmarkColor("HEADER", "\033[92m") +BC_WARNING = BenchmarkColor("WARNING", "\033[93m") +BC_WHITE = BenchmarkColor("WHITE", "\033[97m") +BC_FAIL = BenchmarkColor("FAIL", "\033[91m") +BC_ENDC = BenchmarkColor("ENDC", "\033[0m") +BC_BOLD = BenchmarkColor("BOLD", "\033[1m") +BC_UNDERLINE = BenchmarkColor("UNDERLINE", 
"\033[4m") + +UTEST_MIN_REPETITIONS = 2 +UTEST_OPTIMAL_REPETITIONS = 9 # Lowest reasonable number, More is better. +UTEST_COL_NAME = "_pvalue" + +_TIME_UNIT_TO_SECONDS_MULTIPLIER = { + "s": 1.0, + "ms": 1e-3, + "us": 1e-6, + "ns": 1e-9, +} + + +def color_format(use_color, fmt_str, *args, **kwargs): + """ + Return the result of 'fmt_str.format(*args, **kwargs)' after transforming + 'args' and 'kwargs' according to the value of 'use_color'. If 'use_color' + is False then all color codes in 'args' and 'kwargs' are replaced with + the empty string. + """ + assert use_color is True or use_color is False + if not use_color: + args = [ + arg if not isinstance(arg, BenchmarkColor) else BC_NONE + for arg in args + ] + kwargs = { + key: arg if not isinstance(arg, BenchmarkColor) else BC_NONE + for key, arg in kwargs.items() + } + return fmt_str.format(*args, **kwargs) + + +def find_longest_name(benchmark_list): + """ + Return the length of the longest benchmark name in a given list of + benchmark JSON objects + """ + longest_name = 1 + for bc in benchmark_list: + if len(bc["name"]) > longest_name: + longest_name = len(bc["name"]) + return longest_name + + +def calculate_change(old_val, new_val): + """ + Return a float representing the decimal change between old_val and new_val. + """ + if old_val == 0 and new_val == 0: + return 0.0 + if old_val == 0: + return float(new_val - old_val) / (float(old_val + new_val) / 2) + return float(new_val - old_val) / abs(old_val) + + +def filter_benchmark(json_orig, family, replacement=""): + """ + Apply a filter to the json, and only leave the 'family' of benchmarks. + """ + regex = re.compile(family) + filtered = {} + filtered["benchmarks"] = [] + for be in json_orig["benchmarks"]: + if not regex.search(be["name"]): + continue + filteredbench = copy.deepcopy(be) # Do NOT modify the old name! 
+ filteredbench["name"] = regex.sub(replacement, filteredbench["name"]) + filtered["benchmarks"].append(filteredbench) + return filtered + + +def get_unique_benchmark_names(json): + """ + While *keeping* the order, give all the unique 'names' used for benchmarks. + """ + seen = set() + uniqued = [ + x["name"] + for x in json["benchmarks"] + if x["name"] not in seen and (seen.add(x["name"]) or True) + ] + return uniqued + + +def intersect(list1, list2): + """ + Given two lists, get a new list consisting of the elements only contained + in *both of the input lists*, while preserving the ordering. + """ + return [x for x in list1 if x in list2] + + +def is_potentially_comparable_benchmark(x): + return "time_unit" in x and "real_time" in x and "cpu_time" in x + + +def partition_benchmarks(json1, json2): + """ + While preserving the ordering, find benchmarks with the same names in + both of the inputs, and group them. + (i.e. partition/filter into groups with common name) + """ + json1_unique_names = get_unique_benchmark_names(json1) + json2_unique_names = get_unique_benchmark_names(json2) + names = intersect(json1_unique_names, json2_unique_names) + partitions = [] + for name in names: + time_unit = None + # Pick the time unit from the first entry of the lhs benchmark. + # We should be careful not to crash with unexpected input. + for x in json1["benchmarks"]: + if x["name"] == name and is_potentially_comparable_benchmark(x): + time_unit = x["time_unit"] + break + if time_unit is None: + continue + # Filter by name and time unit. + # All the repetitions are assumed to be comparable. 
+ lhs = [ + x + for x in json1["benchmarks"] + if x["name"] == name and x["time_unit"] == time_unit + ] + rhs = [ + x + for x in json2["benchmarks"] + if x["name"] == name and x["time_unit"] == time_unit + ] + partitions.append([lhs, rhs]) + return partitions + + +def get_timedelta_field_as_seconds(benchmark, field_name): + """ + Get value of field_name field of benchmark, which is time with time unit + time_unit, as time in seconds. + """ + timedelta = benchmark[field_name] + time_unit = benchmark.get("time_unit", "s") + return timedelta * _TIME_UNIT_TO_SECONDS_MULTIPLIER.get(time_unit) + + +def calculate_geomean(json): + """ + Extract all real/cpu times from all the benchmarks as seconds, + and calculate their geomean. + """ + times = [] + for benchmark in json["benchmarks"]: + if "run_type" in benchmark and benchmark["run_type"] == "aggregate": + continue + times.append( + [ + get_timedelta_field_as_seconds(benchmark, "real_time"), + get_timedelta_field_as_seconds(benchmark, "cpu_time"), + ] + ) + return gmean(times) if times else array([]) + + +def extract_field(partition, field_name): + # The count of elements may be different. We want *all* of them. + lhs = [x[field_name] for x in partition[0]] + rhs = [x[field_name] for x in partition[1]] + return [lhs, rhs] + + +def calc_utest(timings_cpu, timings_time): + min_rep_cnt = min( + len(timings_time[0]), + len(timings_time[1]), + len(timings_cpu[0]), + len(timings_cpu[1]), + ) + + # Does *everything* has at least UTEST_MIN_REPETITIONS repetitions? 
+ if min_rep_cnt < UTEST_MIN_REPETITIONS: + return False, None, None + + time_pvalue = mannwhitneyu( + timings_time[0], timings_time[1], alternative="two-sided" + ).pvalue + cpu_pvalue = mannwhitneyu( + timings_cpu[0], timings_cpu[1], alternative="two-sided" + ).pvalue + + return (min_rep_cnt >= UTEST_OPTIMAL_REPETITIONS), cpu_pvalue, time_pvalue + + +def print_utest(bc_name, utest, utest_alpha, first_col_width, use_color=True): + def get_utest_color(pval): + return BC_FAIL if pval >= utest_alpha else BC_OKGREEN + + # Check if we failed miserably with minimum required repetitions for utest + if ( + not utest["have_optimal_repetitions"] + and utest["cpu_pvalue"] is None + and utest["time_pvalue"] is None + ): + return [] + + dsc = "U Test, Repetitions: {} vs {}".format( + utest["nr_of_repetitions"], utest["nr_of_repetitions_other"] + ) + dsc_color = BC_OKGREEN + + # We still got some results to show but issue a warning about it. + if not utest["have_optimal_repetitions"]: + dsc_color = BC_WARNING + dsc += ". WARNING: Results unreliable! {}+ repetitions recommended.".format( + UTEST_OPTIMAL_REPETITIONS + ) + + special_str = "{}{:<{}s}{endc}{}{:16.4f}{endc}{}{:16.4f}{endc}{} {}" + + return [ + color_format( + use_color, + special_str, + BC_HEADER, + "{}{}".format(bc_name, UTEST_COL_NAME), + first_col_width, + get_utest_color(utest["time_pvalue"]), + utest["time_pvalue"], + get_utest_color(utest["cpu_pvalue"]), + utest["cpu_pvalue"], + dsc_color, + dsc, + endc=BC_ENDC, + ) + ] + + +def get_difference_report(json1, json2, utest=False): + """ + Calculate and report the difference between each test of two benchmarks + runs specified as 'json1' and 'json2'. Output is another json containing + relevant details for each test run. 
+ """ + assert utest is True or utest is False + + diff_report = [] + partitions = partition_benchmarks(json1, json2) + for partition in partitions: + benchmark_name = partition[0][0]["name"] + label = partition[0][0]["label"] if "label" in partition[0][0] else "" + time_unit = partition[0][0]["time_unit"] + measurements = [] + utest_results = {} + # Careful, we may have different repetition count. + for i in range(min(len(partition[0]), len(partition[1]))): + bn = partition[0][i] + other_bench = partition[1][i] + measurements.append( + { + "real_time": bn["real_time"], + "cpu_time": bn["cpu_time"], + "real_time_other": other_bench["real_time"], + "cpu_time_other": other_bench["cpu_time"], + "time": calculate_change( + bn["real_time"], other_bench["real_time"] + ), + "cpu": calculate_change( + bn["cpu_time"], other_bench["cpu_time"] + ), + } + ) + + # After processing the whole partition, if requested, do the U test. + if utest: + timings_cpu = extract_field(partition, "cpu_time") + timings_time = extract_field(partition, "real_time") + have_optimal_repetitions, cpu_pvalue, time_pvalue = calc_utest( + timings_cpu, timings_time + ) + if cpu_pvalue is not None and time_pvalue is not None: + utest_results = { + "have_optimal_repetitions": have_optimal_repetitions, + "cpu_pvalue": cpu_pvalue, + "time_pvalue": time_pvalue, + "nr_of_repetitions": len(timings_cpu[0]), + "nr_of_repetitions_other": len(timings_cpu[1]), + } + + # Store only if we had any measurements for given benchmark. + # E.g. partition_benchmarks will filter out the benchmarks having + # time units which are not compatible with other time units in the + # benchmark suite. 
+ if measurements: + run_type = ( + partition[0][0]["run_type"] + if "run_type" in partition[0][0] + else "" + ) + aggregate_name = ( + partition[0][0]["aggregate_name"] + if run_type == "aggregate" + and "aggregate_name" in partition[0][0] + else "" + ) + diff_report.append( + { + "name": benchmark_name, + "label": label, + "measurements": measurements, + "time_unit": time_unit, + "run_type": run_type, + "aggregate_name": aggregate_name, + "utest": utest_results, + } + ) + + lhs_gmean = calculate_geomean(json1) + rhs_gmean = calculate_geomean(json2) + if lhs_gmean.any() and rhs_gmean.any(): + diff_report.append( + { + "name": "OVERALL_GEOMEAN", + "label": "", + "measurements": [ + { + "real_time": lhs_gmean[0], + "cpu_time": lhs_gmean[1], + "real_time_other": rhs_gmean[0], + "cpu_time_other": rhs_gmean[1], + "time": calculate_change(lhs_gmean[0], rhs_gmean[0]), + "cpu": calculate_change(lhs_gmean[1], rhs_gmean[1]), + } + ], + "time_unit": "s", + "run_type": "aggregate", + "aggregate_name": "geomean", + "utest": {}, + } + ) + + return diff_report + + +def print_difference_report( + json_diff_report, + include_aggregates_only=False, + utest=False, + utest_alpha=0.05, + use_color=True, +): + """ + Calculate and report the difference between each test of two benchmarks + runs specified as 'json1' and 'json2'. 
+ """ + assert utest is True or utest is False + + def get_color(res): + if res > 0.05: + return BC_FAIL + elif res > -0.07: + return BC_WHITE + else: + return BC_CYAN + + first_col_width = find_longest_name(json_diff_report) + first_col_width = max(first_col_width, len("Benchmark")) + first_col_width += len(UTEST_COL_NAME) + first_line = "{:<{}s}Time CPU Time Old Time New CPU Old CPU New".format( + "Benchmark", 12 + first_col_width + ) + output_strs = [first_line, "-" * len(first_line)] + + fmt_str = "{}{:<{}s}{endc}{}{:+16.4f}{endc}{}{:+16.4f}{endc}{:14.0f}{:14.0f}{endc}{:14.0f}{:14.0f}" + for benchmark in json_diff_report: + # *If* we were asked to only include aggregates, + # and if it is non-aggregate, then don't print it. + if ( + not include_aggregates_only + or "run_type" not in benchmark + or benchmark["run_type"] == "aggregate" + ): + for measurement in benchmark["measurements"]: + output_strs += [ + color_format( + use_color, + fmt_str, + BC_HEADER, + benchmark["name"], + first_col_width, + get_color(measurement["time"]), + measurement["time"], + get_color(measurement["cpu"]), + measurement["cpu"], + measurement["real_time"], + measurement["real_time_other"], + measurement["cpu_time"], + measurement["cpu_time_other"], + endc=BC_ENDC, + ) + ] + + # After processing the measurements, if requested and + # if applicable (e.g. u-test exists for given benchmark), + # print the U test. 
+ if utest and benchmark["utest"]: + output_strs += print_utest( + benchmark["name"], + benchmark["utest"], + utest_alpha=utest_alpha, + first_col_width=first_col_width, + use_color=use_color, + ) + + return output_strs + + +############################################################################### +# Unit tests + + +class TestGetUniqueBenchmarkNames(unittest.TestCase): + def load_results(self): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput = os.path.join(testInputs, "test3_run0.json") + with open(testOutput, "r") as f: + json = json.load(f) + return json + + def test_basic(self): + expect_lines = [ + "BM_One", + "BM_Two", + "short", # These two are not sorted + "medium", # These two are not sorted + ] + json = self.load_results() + output_lines = get_unique_benchmark_names(json) + print("\n") + print("\n".join(output_lines)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + self.assertEqual(expect_lines[i], output_lines[i]) + + +class TestReportDifference(unittest.TestCase): + @classmethod + def setUpClass(cls): + def load_results(): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput1 = os.path.join(testInputs, "test1_run1.json") + testOutput2 = os.path.join(testInputs, "test1_run2.json") + with open(testOutput1, "r") as f: + json1 = json.load(f) + with open(testOutput2, "r") as f: + json2 = json.load(f) + return json1, json2 + + json1, json2 = load_results() + cls.json_diff_report = get_difference_report(json1, json2) + + def test_json_diff_report_pretty_printing(self): + expect_lines = [ + ["BM_SameTimes", "+0.0000", "+0.0000", "10", "10", "10", "10"], + ["BM_2xFaster", "-0.5000", "-0.5000", "50", "25", "50", "25"], + ["BM_2xSlower", "+1.0000", "+1.0000", "50", "100", "50", "100"], + [ + "BM_1PercentFaster", + "-0.0100", + "-0.0100", + "100", + "99", + "100", + "99", + ], 
+ [ + "BM_1PercentSlower", + "+0.0100", + "+0.0100", + "100", + "101", + "100", + "101", + ], + [ + "BM_10PercentFaster", + "-0.1000", + "-0.1000", + "100", + "90", + "100", + "90", + ], + [ + "BM_10PercentSlower", + "+0.1000", + "+0.1000", + "100", + "110", + "100", + "110", + ], + [ + "BM_100xSlower", + "+99.0000", + "+99.0000", + "100", + "10000", + "100", + "10000", + ], + [ + "BM_100xFaster", + "-0.9900", + "-0.9900", + "10000", + "100", + "10000", + "100", + ], + [ + "BM_10PercentCPUToTime", + "+0.1000", + "-0.1000", + "100", + "110", + "100", + "90", + ], + ["BM_ThirdFaster", "-0.3333", "-0.3334", "100", "67", "100", "67"], + ["BM_NotBadTimeUnit", "-0.9000", "+0.2000", "0", "0", "0", "1"], + ["BM_hasLabel", "+0.0000", "+0.0000", "1", "1", "1", "1"], + ["OVERALL_GEOMEAN", "-0.8113", "-0.7779", "0", "0", "0", "0"], + ] + output_lines_with_header = print_difference_report( + self.json_diff_report, use_color=False + ) + output_lines = output_lines_with_header[2:] + print("\n") + print("\n".join(output_lines_with_header)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + self.assertEqual(len(parts), 7) + self.assertEqual(expect_lines[i], parts) + + def test_json_diff_report_output(self): + expected_output = [ + { + "name": "BM_SameTimes", + "label": "", + "measurements": [ + { + "time": 0.0000, + "cpu": 0.0000, + "real_time": 10, + "real_time_other": 10, + "cpu_time": 10, + "cpu_time_other": 10, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_2xFaster", + "label": "", + "measurements": [ + { + "time": -0.5000, + "cpu": -0.5000, + "real_time": 50, + "real_time_other": 25, + "cpu_time": 50, + "cpu_time_other": 25, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_2xSlower", + "label": "", + "measurements": [ + { + "time": 1.0000, + "cpu": 1.0000, + "real_time": 50, + "real_time_other": 100, + "cpu_time": 50, + 
"cpu_time_other": 100, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_1PercentFaster", + "label": "", + "measurements": [ + { + "time": -0.0100, + "cpu": -0.0100, + "real_time": 100, + "real_time_other": 98.9999999, + "cpu_time": 100, + "cpu_time_other": 98.9999999, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_1PercentSlower", + "label": "", + "measurements": [ + { + "time": 0.0100, + "cpu": 0.0100, + "real_time": 100, + "real_time_other": 101, + "cpu_time": 100, + "cpu_time_other": 101, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_10PercentFaster", + "label": "", + "measurements": [ + { + "time": -0.1000, + "cpu": -0.1000, + "real_time": 100, + "real_time_other": 90, + "cpu_time": 100, + "cpu_time_other": 90, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_10PercentSlower", + "label": "", + "measurements": [ + { + "time": 0.1000, + "cpu": 0.1000, + "real_time": 100, + "real_time_other": 110, + "cpu_time": 100, + "cpu_time_other": 110, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_100xSlower", + "label": "", + "measurements": [ + { + "time": 99.0000, + "cpu": 99.0000, + "real_time": 100, + "real_time_other": 10000, + "cpu_time": 100, + "cpu_time_other": 10000, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_100xFaster", + "label": "", + "measurements": [ + { + "time": -0.9900, + "cpu": -0.9900, + "real_time": 10000, + "real_time_other": 100, + "cpu_time": 10000, + "cpu_time_other": 100, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_10PercentCPUToTime", + "label": "", + "measurements": [ + { + "time": 0.1000, + "cpu": -0.1000, + "real_time": 100, + "real_time_other": 110, + "cpu_time": 100, + "cpu_time_other": 90, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_ThirdFaster", + "label": "", + "measurements": [ + { + "time": -0.3333, + "cpu": -0.3334, + "real_time": 100, + "real_time_other": 67, + 
"cpu_time": 100, + "cpu_time_other": 67, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_NotBadTimeUnit", + "label": "", + "measurements": [ + { + "time": -0.9000, + "cpu": 0.2000, + "real_time": 0.4, + "real_time_other": 0.04, + "cpu_time": 0.5, + "cpu_time_other": 0.6, + } + ], + "time_unit": "s", + "utest": {}, + }, + { + "name": "BM_hasLabel", + "label": "a label", + "measurements": [ + { + "time": 0.0000, + "cpu": 0.0000, + "real_time": 1, + "real_time_other": 1, + "cpu_time": 1, + "cpu_time_other": 1, + } + ], + "time_unit": "s", + "utest": {}, + }, + { + "name": "OVERALL_GEOMEAN", + "label": "", + "measurements": [ + { + "real_time": 3.1622776601683826e-06, + "cpu_time": 3.2130844755623912e-06, + "real_time_other": 1.9768988699420897e-07, + "cpu_time_other": 2.397447755209533e-07, + "time": -0.8112976497120911, + "cpu": -0.7778551721181174, + } + ], + "time_unit": "s", + "run_type": "aggregate", + "aggregate_name": "geomean", + "utest": {}, + }, + ] + self.assertEqual(len(self.json_diff_report), len(expected_output)) + for out, expected in zip(self.json_diff_report, expected_output): + self.assertEqual(out["name"], expected["name"]) + self.assertEqual(out["label"], expected["label"]) + self.assertEqual(out["time_unit"], expected["time_unit"]) + assert_utest(self, out, expected) + assert_measurements(self, out, expected) + + +class TestReportDifferenceBetweenFamilies(unittest.TestCase): + @classmethod + def setUpClass(cls): + def load_result(): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput = os.path.join(testInputs, "test2_run.json") + with open(testOutput, "r") as f: + json = json.load(f) + return json + + json = load_result() + json1 = filter_benchmark(json, "BM_Z.ro", ".") + json2 = filter_benchmark(json, "BM_O.e", ".") + cls.json_diff_report = get_difference_report(json1, json2) + + def test_json_diff_report_pretty_printing(self): + expect_lines = [ + [".", 
"-0.5000", "-0.5000", "10", "5", "10", "5"], + ["./4", "-0.5000", "-0.5000", "40", "20", "40", "20"], + ["Prefix/.", "-0.5000", "-0.5000", "20", "10", "20", "10"], + ["Prefix/./3", "-0.5000", "-0.5000", "30", "15", "30", "15"], + ["OVERALL_GEOMEAN", "-0.5000", "-0.5000", "0", "0", "0", "0"], + ] + output_lines_with_header = print_difference_report( + self.json_diff_report, use_color=False + ) + output_lines = output_lines_with_header[2:] + print("\n") + print("\n".join(output_lines_with_header)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + self.assertEqual(len(parts), 7) + self.assertEqual(expect_lines[i], parts) + + def test_json_diff_report(self): + expected_output = [ + { + "name": ".", + "measurements": [ + { + "time": -0.5, + "cpu": -0.5, + "real_time": 10, + "real_time_other": 5, + "cpu_time": 10, + "cpu_time_other": 5, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "./4", + "measurements": [ + { + "time": -0.5, + "cpu": -0.5, + "real_time": 40, + "real_time_other": 20, + "cpu_time": 40, + "cpu_time_other": 20, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "Prefix/.", + "measurements": [ + { + "time": -0.5, + "cpu": -0.5, + "real_time": 20, + "real_time_other": 10, + "cpu_time": 20, + "cpu_time_other": 10, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "Prefix/./3", + "measurements": [ + { + "time": -0.5, + "cpu": -0.5, + "real_time": 30, + "real_time_other": 15, + "cpu_time": 30, + "cpu_time_other": 15, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "OVERALL_GEOMEAN", + "measurements": [ + { + "real_time": 2.213363839400641e-08, + "cpu_time": 2.213363839400641e-08, + "real_time_other": 1.1066819197003185e-08, + "cpu_time_other": 1.1066819197003185e-08, + "time": -0.5000000000000009, + "cpu": -0.5000000000000009, + } + ], + "time_unit": "s", + "run_type": "aggregate", + 
"aggregate_name": "geomean", + "utest": {}, + }, + ] + self.assertEqual(len(self.json_diff_report), len(expected_output)) + for out, expected in zip(self.json_diff_report, expected_output): + self.assertEqual(out["name"], expected["name"]) + self.assertEqual(out["time_unit"], expected["time_unit"]) + assert_utest(self, out, expected) + assert_measurements(self, out, expected) + + +class TestReportDifferenceWithUTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + def load_results(): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput1 = os.path.join(testInputs, "test3_run0.json") + testOutput2 = os.path.join(testInputs, "test3_run1.json") + with open(testOutput1, "r") as f: + json1 = json.load(f) + with open(testOutput2, "r") as f: + json2 = json.load(f) + return json1, json2 + + json1, json2 = load_results() + cls.json_diff_report = get_difference_report(json1, json2, utest=True) + + def test_json_diff_report_pretty_printing(self): + expect_lines = [ + ["BM_One", "-0.1000", "+0.1000", "10", "9", "100", "110"], + ["BM_Two", "+0.1111", "-0.0111", "9", "10", "90", "89"], + ["BM_Two", "-0.1250", "-0.1628", "8", "7", "86", "72"], + [ + "BM_Two_pvalue", + "1.0000", + "0.6667", + "U", + "Test,", + "Repetitions:", + "2", + "vs", + "2.", + "WARNING:", + "Results", + "unreliable!", + "9+", + "repetitions", + "recommended.", + ], + ["short", "-0.1250", "-0.0625", "8", "7", "80", "75"], + ["short", "-0.4325", "-0.1351", "8", "5", "77", "67"], + [ + "short_pvalue", + "0.7671", + "0.2000", + "U", + "Test,", + "Repetitions:", + "2", + "vs", + "3.", + "WARNING:", + "Results", + "unreliable!", + "9+", + "repetitions", + "recommended.", + ], + ["medium", "-0.3750", "-0.3375", "8", "5", "80", "53"], + ["OVERALL_GEOMEAN", "+1.6405", "-0.6985", "0", "0", "0", "0"], + ] + output_lines_with_header = print_difference_report( + self.json_diff_report, utest=True, utest_alpha=0.05, use_color=False + ) + 
output_lines = output_lines_with_header[2:] + print("\n") + print("\n".join(output_lines_with_header)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + self.assertEqual(expect_lines[i], parts) + + def test_json_diff_report_pretty_printing_aggregates_only(self): + expect_lines = [ + ["BM_One", "-0.1000", "+0.1000", "10", "9", "100", "110"], + [ + "BM_Two_pvalue", + "1.0000", + "0.6667", + "U", + "Test,", + "Repetitions:", + "2", + "vs", + "2.", + "WARNING:", + "Results", + "unreliable!", + "9+", + "repetitions", + "recommended.", + ], + ["short", "-0.1250", "-0.0625", "8", "7", "80", "75"], + ["short", "-0.4325", "-0.1351", "8", "5", "77", "67"], + [ + "short_pvalue", + "0.7671", + "0.2000", + "U", + "Test,", + "Repetitions:", + "2", + "vs", + "3.", + "WARNING:", + "Results", + "unreliable!", + "9+", + "repetitions", + "recommended.", + ], + ["OVERALL_GEOMEAN", "+1.6405", "-0.6985", "0", "0", "0", "0"], + ] + output_lines_with_header = print_difference_report( + self.json_diff_report, + include_aggregates_only=True, + utest=True, + utest_alpha=0.05, + use_color=False, + ) + output_lines = output_lines_with_header[2:] + print("\n") + print("\n".join(output_lines_with_header)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + self.assertEqual(expect_lines[i], parts) + + def test_json_diff_report(self): + expected_output = [ + { + "name": "BM_One", + "measurements": [ + { + "time": -0.1, + "cpu": 0.1, + "real_time": 10, + "real_time_other": 9, + "cpu_time": 100, + "cpu_time_other": 110, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_Two", + "measurements": [ + { + "time": 0.1111111111111111, + "cpu": -0.011111111111111112, + "real_time": 9, + "real_time_other": 10, + "cpu_time": 90, + "cpu_time_other": 89, + }, + { + "time": -0.125, + "cpu": 
-0.16279069767441862, + "real_time": 8, + "real_time_other": 7, + "cpu_time": 86, + "cpu_time_other": 72, + }, + ], + "time_unit": "ns", + "utest": { + "have_optimal_repetitions": False, + "cpu_pvalue": 0.6666666666666666, + "time_pvalue": 1.0, + }, + }, + { + "name": "short", + "measurements": [ + { + "time": -0.125, + "cpu": -0.0625, + "real_time": 8, + "real_time_other": 7, + "cpu_time": 80, + "cpu_time_other": 75, + }, + { + "time": -0.4325, + "cpu": -0.13506493506493514, + "real_time": 8, + "real_time_other": 4.54, + "cpu_time": 77, + "cpu_time_other": 66.6, + }, + ], + "time_unit": "ns", + "utest": { + "have_optimal_repetitions": False, + "cpu_pvalue": 0.2, + "time_pvalue": 0.7670968684102772, + }, + }, + { + "name": "medium", + "measurements": [ + { + "time": -0.375, + "cpu": -0.3375, + "real_time": 8, + "real_time_other": 5, + "cpu_time": 80, + "cpu_time_other": 53, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "OVERALL_GEOMEAN", + "measurements": [ + { + "real_time": 8.48528137423858e-09, + "cpu_time": 8.441336246629233e-08, + "real_time_other": 2.2405267593145244e-08, + "cpu_time_other": 2.5453661413660466e-08, + "time": 1.6404861082353634, + "cpu": -0.6984640740519662, + } + ], + "time_unit": "s", + "run_type": "aggregate", + "aggregate_name": "geomean", + "utest": {}, + }, + ] + self.assertEqual(len(self.json_diff_report), len(expected_output)) + for out, expected in zip(self.json_diff_report, expected_output): + self.assertEqual(out["name"], expected["name"]) + self.assertEqual(out["time_unit"], expected["time_unit"]) + assert_utest(self, out, expected) + assert_measurements(self, out, expected) + + +class TestReportDifferenceWithUTestWhileDisplayingAggregatesOnly( + unittest.TestCase +): + @classmethod + def setUpClass(cls): + def load_results(): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput1 = os.path.join(testInputs, "test3_run0.json") + testOutput2 = 
os.path.join(testInputs, "test3_run1.json") + with open(testOutput1, "r") as f: + json1 = json.load(f) + with open(testOutput2, "r") as f: + json2 = json.load(f) + return json1, json2 + + json1, json2 = load_results() + cls.json_diff_report = get_difference_report(json1, json2, utest=True) + + def test_json_diff_report_pretty_printing(self): + expect_lines = [ + ["BM_One", "-0.1000", "+0.1000", "10", "9", "100", "110"], + ["BM_Two", "+0.1111", "-0.0111", "9", "10", "90", "89"], + ["BM_Two", "-0.1250", "-0.1628", "8", "7", "86", "72"], + [ + "BM_Two_pvalue", + "1.0000", + "0.6667", + "U", + "Test,", + "Repetitions:", + "2", + "vs", + "2.", + "WARNING:", + "Results", + "unreliable!", + "9+", + "repetitions", + "recommended.", + ], + ["short", "-0.1250", "-0.0625", "8", "7", "80", "75"], + ["short", "-0.4325", "-0.1351", "8", "5", "77", "67"], + [ + "short_pvalue", + "0.7671", + "0.2000", + "U", + "Test,", + "Repetitions:", + "2", + "vs", + "3.", + "WARNING:", + "Results", + "unreliable!", + "9+", + "repetitions", + "recommended.", + ], + ["medium", "-0.3750", "-0.3375", "8", "5", "80", "53"], + ["OVERALL_GEOMEAN", "+1.6405", "-0.6985", "0", "0", "0", "0"], + ] + output_lines_with_header = print_difference_report( + self.json_diff_report, utest=True, utest_alpha=0.05, use_color=False + ) + output_lines = output_lines_with_header[2:] + print("\n") + print("\n".join(output_lines_with_header)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + self.assertEqual(expect_lines[i], parts) + + def test_json_diff_report(self): + expected_output = [ + { + "name": "BM_One", + "measurements": [ + { + "time": -0.1, + "cpu": 0.1, + "real_time": 10, + "real_time_other": 9, + "cpu_time": 100, + "cpu_time_other": 110, + } + ], + "time_unit": "ns", + "utest": {}, + }, + { + "name": "BM_Two", + "measurements": [ + { + "time": 0.1111111111111111, + "cpu": -0.011111111111111112, + 
"real_time": 9, + "real_time_other": 10, + "cpu_time": 90, + "cpu_time_other": 89, + }, + { + "time": -0.125, + "cpu": -0.16279069767441862, + "real_time": 8, + "real_time_other": 7, + "cpu_time": 86, + "cpu_time_other": 72, + }, + ], + "time_unit": "ns", + "utest": { + "have_optimal_repetitions": False, + "cpu_pvalue": 0.6666666666666666, + "time_pvalue": 1.0, + }, + }, + { + "name": "short", + "measurements": [ + { + "time": -0.125, + "cpu": -0.0625, + "real_time": 8, + "real_time_other": 7, + "cpu_time": 80, + "cpu_time_other": 75, + }, + { + "time": -0.4325, + "cpu": -0.13506493506493514, + "real_time": 8, + "real_time_other": 4.54, + "cpu_time": 77, + "cpu_time_other": 66.6, + }, + ], + "time_unit": "ns", + "utest": { + "have_optimal_repetitions": False, + "cpu_pvalue": 0.2, + "time_pvalue": 0.7670968684102772, + }, + }, + { + "name": "medium", + "measurements": [ + { + "real_time_other": 5, + "cpu_time": 80, + "time": -0.375, + "real_time": 8, + "cpu_time_other": 53, + "cpu": -0.3375, + } + ], + "utest": {}, + "time_unit": "ns", + "aggregate_name": "", + }, + { + "name": "OVERALL_GEOMEAN", + "measurements": [ + { + "real_time": 8.48528137423858e-09, + "cpu_time": 8.441336246629233e-08, + "real_time_other": 2.2405267593145244e-08, + "cpu_time_other": 2.5453661413660466e-08, + "time": 1.6404861082353634, + "cpu": -0.6984640740519662, + } + ], + "time_unit": "s", + "run_type": "aggregate", + "aggregate_name": "geomean", + "utest": {}, + }, + ] + self.assertEqual(len(self.json_diff_report), len(expected_output)) + for out, expected in zip(self.json_diff_report, expected_output): + self.assertEqual(out["name"], expected["name"]) + self.assertEqual(out["time_unit"], expected["time_unit"]) + assert_utest(self, out, expected) + assert_measurements(self, out, expected) + + +class TestReportDifferenceForPercentageAggregates(unittest.TestCase): + @classmethod + def setUpClass(cls): + def load_results(): + import json + + testInputs = os.path.join( + 
os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput1 = os.path.join(testInputs, "test4_run0.json") + testOutput2 = os.path.join(testInputs, "test4_run1.json") + with open(testOutput1, "r") as f: + json1 = json.load(f) + with open(testOutput2, "r") as f: + json2 = json.load(f) + return json1, json2 + + json1, json2 = load_results() + cls.json_diff_report = get_difference_report(json1, json2, utest=True) + + def test_json_diff_report_pretty_printing(self): + expect_lines = [["whocares", "-0.5000", "+0.5000", "0", "0", "0", "0"]] + output_lines_with_header = print_difference_report( + self.json_diff_report, utest=True, utest_alpha=0.05, use_color=False + ) + output_lines = output_lines_with_header[2:] + print("\n") + print("\n".join(output_lines_with_header)) + self.assertEqual(len(output_lines), len(expect_lines)) + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + self.assertEqual(expect_lines[i], parts) + + def test_json_diff_report(self): + expected_output = [ + { + "name": "whocares", + "measurements": [ + { + "time": -0.5, + "cpu": 0.5, + "real_time": 0.01, + "real_time_other": 0.005, + "cpu_time": 0.10, + "cpu_time_other": 0.15, + } + ], + "time_unit": "ns", + "utest": {}, + } + ] + self.assertEqual(len(self.json_diff_report), len(expected_output)) + for out, expected in zip(self.json_diff_report, expected_output): + self.assertEqual(out["name"], expected["name"]) + self.assertEqual(out["time_unit"], expected["time_unit"]) + assert_utest(self, out, expected) + assert_measurements(self, out, expected) + + +class TestReportSorting(unittest.TestCase): + @classmethod + def setUpClass(cls): + def load_result(): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput = os.path.join(testInputs, "test4_run.json") + with open(testOutput, "r") as f: + json = json.load(f) + return json + + cls.json = load_result() + + def 
test_json_diff_report_pretty_printing(self): + import util + + expected_names = [ + "99 family 0 instance 0 repetition 0", + "98 family 0 instance 0 repetition 1", + "97 family 0 instance 0 aggregate", + "96 family 0 instance 1 repetition 0", + "95 family 0 instance 1 repetition 1", + "94 family 0 instance 1 aggregate", + "93 family 1 instance 0 repetition 0", + "92 family 1 instance 0 repetition 1", + "91 family 1 instance 0 aggregate", + "90 family 1 instance 1 repetition 0", + "89 family 1 instance 1 repetition 1", + "88 family 1 instance 1 aggregate", + ] + + for n in range(len(self.json["benchmarks"]) ** 2): + random.shuffle(self.json["benchmarks"]) + sorted_benchmarks = util.sort_benchmark_results(self.json)[ + "benchmarks" + ] + self.assertEqual(len(expected_names), len(sorted_benchmarks)) + for out, expected in zip(sorted_benchmarks, expected_names): + self.assertEqual(out["name"], expected) + + +class TestReportDifferenceWithUTestWhileDisplayingAggregatesOnly2( + unittest.TestCase +): + @classmethod + def setUpClass(cls): + def load_results(): + import json + + testInputs = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "Inputs" + ) + testOutput1 = os.path.join(testInputs, "test5_run0.json") + testOutput2 = os.path.join(testInputs, "test5_run1.json") + with open(testOutput1, "r") as f: + json1 = json.load(f) + json1["benchmarks"] = [ + json1["benchmarks"][0] for i in range(1000) + ] + with open(testOutput2, "r") as f: + json2 = json.load(f) + json2["benchmarks"] = [ + json2["benchmarks"][0] for i in range(1000) + ] + return json1, json2 + + json1, json2 = load_results() + cls.json_diff_report = get_difference_report(json1, json2, utest=True) + + def test_json_diff_report_pretty_printing(self): + expect_line = [ + "BM_ManyRepetitions_pvalue", + "0.0000", + "0.0000", + "U", + "Test,", + "Repetitions:", + "1000", + "vs", + "1000", + ] + output_lines_with_header = print_difference_report( + self.json_diff_report, utest=True, utest_alpha=0.05, 
use_color=False + ) + output_lines = output_lines_with_header[2:] + found = False + for i in range(0, len(output_lines)): + parts = [x for x in output_lines[i].split(" ") if x] + found = expect_line == parts + if found: + break + self.assertTrue(found) + + def test_json_diff_report(self): + expected_output = [ + { + "name": "BM_ManyRepetitions", + "label": "", + "time_unit": "s", + "run_type": "", + "aggregate_name": "", + "utest": { + "have_optimal_repetitions": True, + "cpu_pvalue": 0.0, + "time_pvalue": 0.0, + "nr_of_repetitions": 1000, + "nr_of_repetitions_other": 1000, + }, + }, + { + "name": "OVERALL_GEOMEAN", + "label": "", + "measurements": [ + { + "real_time": 1.0, + "cpu_time": 1000.000000000069, + "real_time_other": 1000.000000000069, + "cpu_time_other": 1.0, + "time": 999.000000000069, + "cpu": -0.9990000000000001, + } + ], + "time_unit": "s", + "run_type": "aggregate", + "aggregate_name": "geomean", + "utest": {}, + }, + ] + self.assertEqual(len(self.json_diff_report), len(expected_output)) + for out, expected in zip(self.json_diff_report, expected_output): + self.assertEqual(out["name"], expected["name"]) + self.assertEqual(out["time_unit"], expected["time_unit"]) + assert_utest(self, out, expected) + + +def assert_utest(unittest_instance, lhs, rhs): + if lhs["utest"]: + unittest_instance.assertAlmostEqual( + lhs["utest"]["cpu_pvalue"], rhs["utest"]["cpu_pvalue"] + ) + unittest_instance.assertAlmostEqual( + lhs["utest"]["time_pvalue"], rhs["utest"]["time_pvalue"] + ) + unittest_instance.assertEqual( + lhs["utest"]["have_optimal_repetitions"], + rhs["utest"]["have_optimal_repetitions"], + ) + else: + # lhs is empty. assert if rhs is not. 
+ unittest_instance.assertEqual(lhs["utest"], rhs["utest"]) + + +def assert_measurements(unittest_instance, lhs, rhs): + for m1, m2 in zip(lhs["measurements"], rhs["measurements"]): + unittest_instance.assertEqual(m1["real_time"], m2["real_time"]) + unittest_instance.assertEqual(m1["cpu_time"], m2["cpu_time"]) + # m1['time'] and m1['cpu'] hold values which are being calculated, + # and therefore we must use almost-equal pattern. + unittest_instance.assertAlmostEqual(m1["time"], m2["time"], places=4) + unittest_instance.assertAlmostEqual(m1["cpu"], m2["cpu"], places=4) + + +if __name__ == "__main__": + unittest.main() + +# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 +# kate: tab-width: 4; replace-tabs on; indent-width 4; tab-indents: off; +# kate: indent-mode python; remove-trailing-spaces modified; diff --git a/third_party/benchmark/tools/gbench/util.py b/third_party/benchmark/tools/gbench/util.py new file mode 100644 index 0000000..1119a1a --- /dev/null +++ b/third_party/benchmark/tools/gbench/util.py @@ -0,0 +1,231 @@ +"""util.py - General utilities for running, loading, and processing benchmarks""" + +import json +import os +import re +import subprocess +import sys +import tempfile + +# Input file type enumeration +IT_Invalid = 0 +IT_JSON = 1 +IT_Executable = 2 + +_num_magic_bytes = 2 if sys.platform.startswith("win") else 4 + + +def is_executable_file(filename): + """ + Return 'True' if 'filename' names a valid file which is likely + an executable. A file is considered an executable if it starts with the + magic bytes for a EXE, Mach O, or ELF file. 
+ """ + if not os.path.isfile(filename): + return False + with open(filename, mode="rb") as f: + magic_bytes = f.read(_num_magic_bytes) + if sys.platform == "darwin": + return magic_bytes in [ + b"\xfe\xed\xfa\xce", # MH_MAGIC + b"\xce\xfa\xed\xfe", # MH_CIGAM + b"\xfe\xed\xfa\xcf", # MH_MAGIC_64 + b"\xcf\xfa\xed\xfe", # MH_CIGAM_64 + b"\xca\xfe\xba\xbe", # FAT_MAGIC + b"\xbe\xba\xfe\xca", # FAT_CIGAM + ] + elif sys.platform.startswith("win"): + return magic_bytes == b"MZ" + else: + return magic_bytes == b"\x7fELF" + + +def is_json_file(filename): + """ + Returns 'True' if 'filename' names a valid JSON output file. + 'False' otherwise. + """ + try: + with open(filename, "r") as f: + json.load(f) + return True + except BaseException: + pass + return False + + +def classify_input_file(filename): + """ + Return a tuple (type, msg) where 'type' specifies the classified type + of 'filename'. If 'type' is 'IT_Invalid' then 'msg' is a human readable + string representing the error. + """ + ftype = IT_Invalid + err_msg = None + if not os.path.exists(filename): + err_msg = "'%s' does not exist" % filename + elif not os.path.isfile(filename): + err_msg = "'%s' does not name a file" % filename + elif is_executable_file(filename): + ftype = IT_Executable + elif is_json_file(filename): + ftype = IT_JSON + else: + err_msg = ( + "'%s' does not name a valid benchmark executable or JSON file" + % filename + ) + return ftype, err_msg + + +def check_input_file(filename): + """ + Classify the file named by 'filename' and return the classification. + If the file is classified as 'IT_Invalid' print an error message and exit + the program. + """ + ftype, msg = classify_input_file(filename) + if ftype == IT_Invalid: + print("Invalid input file: %s" % msg) + sys.exit(1) + return ftype + + +def find_benchmark_flag(prefix, benchmark_flags): + """ + Search the specified list of flags for a flag matching `` and + if it is found return the arg it specifies. 
If specified more than once the + last value is returned. If the flag is not found None is returned. + """ + assert prefix.startswith("--") and prefix.endswith("=") + result = None + for f in benchmark_flags: + if f.startswith(prefix): + result = f[len(prefix) :] + return result + + +def remove_benchmark_flags(prefix, benchmark_flags): + """ + Return a new list containing the specified benchmark_flags except those + with the specified prefix. + """ + assert prefix.startswith("--") and prefix.endswith("=") + return [f for f in benchmark_flags if not f.startswith(prefix)] + + +def load_benchmark_results(fname, benchmark_filter): + """ + Read benchmark output from a file and return the JSON object. + + Apply benchmark_filter, a regular expression, with nearly the same + semantics of the --benchmark_filter argument. May be None. + Note: the Python regular expression engine is used instead of the + one used by the C++ code, which may produce different results + in complex cases. + + REQUIRES: 'fname' names a file containing JSON benchmark output. + """ + + def benchmark_wanted(benchmark): + if benchmark_filter is None: + return True + name = benchmark.get("run_name", None) or benchmark["name"] + return re.search(benchmark_filter, name) is not None + + with open(fname, "r") as f: + results = json.load(f) + if "context" in results: + if "json_schema_version" in results["context"]: + json_schema_version = results["context"]["json_schema_version"] + if json_schema_version != 1: + print( + "In %s, got unnsupported JSON schema version: %i, expected 1" + % (fname, json_schema_version) + ) + sys.exit(1) + if "benchmarks" in results: + results["benchmarks"] = list( + filter(benchmark_wanted, results["benchmarks"]) + ) + return results + + +def sort_benchmark_results(result): + benchmarks = result["benchmarks"] + + # From inner key to the outer key! 
+ benchmarks = sorted( + benchmarks, + key=lambda benchmark: benchmark["repetition_index"] + if "repetition_index" in benchmark + else -1, + ) + benchmarks = sorted( + benchmarks, + key=lambda benchmark: 1 + if "run_type" in benchmark and benchmark["run_type"] == "aggregate" + else 0, + ) + benchmarks = sorted( + benchmarks, + key=lambda benchmark: benchmark["per_family_instance_index"] + if "per_family_instance_index" in benchmark + else -1, + ) + benchmarks = sorted( + benchmarks, + key=lambda benchmark: benchmark["family_index"] + if "family_index" in benchmark + else -1, + ) + + result["benchmarks"] = benchmarks + return result + + +def run_benchmark(exe_name, benchmark_flags): + """ + Run a benchmark specified by 'exe_name' with the specified + 'benchmark_flags'. The benchmark is run directly as a subprocess to preserve + real time console output. + RETURNS: A JSON object representing the benchmark output + """ + output_name = find_benchmark_flag("--benchmark_out=", benchmark_flags) + is_temp_output = False + if output_name is None: + is_temp_output = True + thandle, output_name = tempfile.mkstemp() + os.close(thandle) + benchmark_flags = list(benchmark_flags) + [ + "--benchmark_out=%s" % output_name + ] + + cmd = [exe_name] + benchmark_flags + print("RUNNING: %s" % " ".join(cmd)) + exitCode = subprocess.call(cmd) + if exitCode != 0: + print("TEST FAILED...") + sys.exit(exitCode) + json_res = load_benchmark_results(output_name, None) + if is_temp_output: + os.unlink(output_name) + return json_res + + +def run_or_load_benchmark(filename, benchmark_flags): + """ + Get the results for a specified benchmark. If 'filename' specifies + an executable benchmark then the results are generated by running the + benchmark. Otherwise 'filename' must name a valid JSON output file, + which is loaded and the result returned. 
+ """ + ftype = check_input_file(filename) + if ftype == IT_JSON: + benchmark_filter = find_benchmark_flag( + "--benchmark_filter=", benchmark_flags + ) + return load_benchmark_results(filename, benchmark_filter) + if ftype == IT_Executable: + return run_benchmark(filename, benchmark_flags) + raise ValueError("Unknown file type %s" % ftype) diff --git a/third_party/benchmark/tools/libpfm.BUILD.bazel b/third_party/benchmark/tools/libpfm.BUILD.bazel new file mode 100644 index 0000000..6269534 --- /dev/null +++ b/third_party/benchmark/tools/libpfm.BUILD.bazel @@ -0,0 +1,22 @@ +# Build rule for libpfm, which is required to collect performance counters for +# BENCHMARK_ENABLE_LIBPFM builds. + +load("@rules_foreign_cc//foreign_cc:defs.bzl", "make") + +filegroup( + name = "pfm_srcs", + srcs = glob(["**"]), +) + +make( + name = "libpfm", + lib_source = ":pfm_srcs", + lib_name = "libpfm", + copts = [ + "-Wno-format-truncation", + "-Wno-use-after-free", + ], + visibility = [ + "//visibility:public", + ], +) diff --git a/third_party/benchmark/tools/requirements.txt b/third_party/benchmark/tools/requirements.txt new file mode 100644 index 0000000..f32f35b --- /dev/null +++ b/third_party/benchmark/tools/requirements.txt @@ -0,0 +1,2 @@ +numpy == 1.25 +scipy == 1.10.0 diff --git a/third_party/benchmark/tools/strip_asm.py b/third_party/benchmark/tools/strip_asm.py new file mode 100755 index 0000000..bc3a774 --- /dev/null +++ b/third_party/benchmark/tools/strip_asm.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 + +""" +strip_asm.py - Cleanup ASM output for the specified file +""" + +import os +import re +import sys +from argparse import ArgumentParser + + +def find_used_labels(asm): + found = set() + label_re = re.compile(r"\s*j[a-z]+\s+\.L([a-zA-Z0-9][a-zA-Z0-9_]*)") + for line in asm.splitlines(): + m = label_re.match(line) + if m: + found.add(".L%s" % m.group(1)) + return found + + +def normalize_labels(asm): + decls = set() + label_decl = 
re.compile("^[.]{0,1}L([a-zA-Z0-9][a-zA-Z0-9_]*)(?=:)") + for line in asm.splitlines(): + m = label_decl.match(line) + if m: + decls.add(m.group(0)) + if len(decls) == 0: + return asm + needs_dot = next(iter(decls))[0] != "." + if not needs_dot: + return asm + for ld in decls: + asm = re.sub(r"(^|\s+)" + ld + r"(?=:|\s)", "\\1." + ld, asm) + return asm + + +def transform_labels(asm): + asm = normalize_labels(asm) + used_decls = find_used_labels(asm) + new_asm = "" + label_decl = re.compile(r"^\.L([a-zA-Z0-9][a-zA-Z0-9_]*)(?=:)") + for line in asm.splitlines(): + m = label_decl.match(line) + if not m or m.group(0) in used_decls: + new_asm += line + new_asm += "\n" + return new_asm + + +def is_identifier(tk): + if len(tk) == 0: + return False + first = tk[0] + if not first.isalpha() and first != "_": + return False + for i in range(1, len(tk)): + c = tk[i] + if not c.isalnum() and c != "_": + return False + return True + + +def process_identifiers(line): + """ + process_identifiers - process all identifiers and modify them to have + consistent names across all platforms; specifically across ELF and MachO. + For example, MachO inserts an additional understore at the beginning of + names. This function removes that. 
+ """ + parts = re.split(r"([a-zA-Z0-9_]+)", line) + new_line = "" + for tk in parts: + if is_identifier(tk): + if tk.startswith("__Z"): + tk = tk[1:] + elif ( + tk.startswith("_") + and len(tk) > 1 + and tk[1].isalpha() + and tk[1] != "Z" + ): + tk = tk[1:] + new_line += tk + return new_line + + +def process_asm(asm): + """ + Strip the ASM of unwanted directives and lines + """ + new_contents = "" + asm = transform_labels(asm) + + # TODO: Add more things we want to remove + discard_regexes = [ + re.compile(r"\s+\..*$"), # directive + re.compile(r"\s*#(NO_APP|APP)$"), # inline ASM + re.compile(r"\s*#.*$"), # comment line + re.compile( + r"\s*\.globa?l\s*([.a-zA-Z_][a-zA-Z0-9$_.]*)" + ), # global directive + re.compile( + r"\s*\.(string|asciz|ascii|[1248]?byte|short|word|long|quad|value|zero)" + ), + ] + keep_regexes: list[re.Pattern] = [] + fn_label_def = re.compile("^[a-zA-Z_][a-zA-Z0-9_.]*:") + for line in asm.splitlines(): + # Remove Mach-O attribute + line = line.replace("@GOTPCREL", "") + add_line = True + for reg in discard_regexes: + if reg.match(line) is not None: + add_line = False + break + for reg in keep_regexes: + if reg.match(line) is not None: + add_line = True + break + if add_line: + if fn_label_def.match(line) and len(new_contents) != 0: + new_contents += "\n" + line = process_identifiers(line) + new_contents += line + new_contents += "\n" + return new_contents + + +def main(): + parser = ArgumentParser(description="generate a stripped assembly file") + parser.add_argument( + "input", + metavar="input", + type=str, + nargs=1, + help="An input assembly file", + ) + parser.add_argument( + "out", metavar="output", type=str, nargs=1, help="The output file" + ) + args, unknown_args = parser.parse_known_args() + input = args.input[0] + output = args.out[0] + if not os.path.isfile(input): + print("ERROR: input file '%s' does not exist" % input) + sys.exit(1) + + with open(input, "r") as f: + contents = f.read() + new_contents = process_asm(contents) + 
with open(output, "w") as f: + f.write(new_contents) + + +if __name__ == "__main__": + main() + +# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 +# kate: tab-width: 4; replace-tabs on; indent-width 4; tab-indents: off; +# kate: indent-mode python; remove-trailing-spaces modified; diff --git a/third_party/googletest/.clang-format b/third_party/googletest/.clang-format new file mode 100644 index 0000000..5b9bfe6 --- /dev/null +++ b/third_party/googletest/.clang-format @@ -0,0 +1,4 @@ +# Run manually to reformat a file: +# clang-format -i --style=file +Language: Cpp +BasedOnStyle: Google diff --git a/third_party/googletest/.github/ISSUE_TEMPLATE/00-bug_report.yml b/third_party/googletest/.github/ISSUE_TEMPLATE/00-bug_report.yml new file mode 100644 index 0000000..586779a --- /dev/null +++ b/third_party/googletest/.github/ISSUE_TEMPLATE/00-bug_report.yml @@ -0,0 +1,53 @@ +name: Bug Report +description: Let us know that something does not work as expected. +title: "[Bug]: Please title this bug report" +body: + - type: textarea + id: what-happened + attributes: + label: Describe the issue + description: What happened, and what did you expect to happen? + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce the problem + description: It is important that we are able to reproduce the problem that you are experiencing. Please provide all code and relevant steps to reproduce the problem, including your `BUILD`/`CMakeLists.txt` file and build commands. Links to a GitHub branch or [godbolt.org](https://godbolt.org/) that demonstrate the problem are also helpful. + validations: + required: true + - type: textarea + id: version + attributes: + label: What version of GoogleTest are you using? + description: Please include the output of `git rev-parse HEAD` or the GoogleTest release version number that you are using. 
+ validations: + required: true + - type: textarea + id: os + attributes: + label: What operating system and version are you using? + description: If you are using a Linux distribution please include the name and version of the distribution as well. + validations: + required: true + - type: textarea + id: compiler + attributes: + label: What compiler and version are you using? + description: Please include the output of `gcc -v` or `clang -v`, or the equivalent for your compiler. + validations: + required: true + - type: textarea + id: buildsystem + attributes: + label: What build system are you using? + description: Please include the output of `bazel --version` or `cmake --version`, or the equivalent for your build system. + validations: + required: true + - type: textarea + id: additional + attributes: + label: Additional context + description: Add any other context about the problem here. + validations: + required: false diff --git a/third_party/googletest/.github/ISSUE_TEMPLATE/10-feature_request.yml b/third_party/googletest/.github/ISSUE_TEMPLATE/10-feature_request.yml new file mode 100644 index 0000000..f3bbc09 --- /dev/null +++ b/third_party/googletest/.github/ISSUE_TEMPLATE/10-feature_request.yml @@ -0,0 +1,33 @@ +name: Feature request +description: Propose a new feature. +title: "[FR]: Please title this feature request" +labels: "enhancement" +body: + - type: textarea + id: version + attributes: + label: Does the feature exist in the most recent commit? + description: We recommend using the latest commit from GitHub in your projects. + validations: + required: true + - type: textarea + id: why + attributes: + label: Why do we need this feature? + description: Ideally, explain why a combination of existing features cannot be used instead. + validations: + required: true + - type: textarea + id: proposal + attributes: + label: Describe the proposal. + description: Include a detailed description of the feature, with usage examples. 
+ validations: + required: true + - type: textarea + id: platform + attributes: + label: Is the feature specific to an operating system, compiler, or build system version? + description: If it is, please specify which versions. + validations: + required: true diff --git a/third_party/googletest/.github/ISSUE_TEMPLATE/config.yml b/third_party/googletest/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..65170d1 --- /dev/null +++ b/third_party/googletest/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: Get Help + url: https://github.com/google/googletest/discussions + about: Please ask and answer questions here. diff --git a/third_party/googletest/.github/workflows/gtest-ci.yml b/third_party/googletest/.github/workflows/gtest-ci.yml new file mode 100644 index 0000000..03a8cc5 --- /dev/null +++ b/third_party/googletest/.github/workflows/gtest-ci.yml @@ -0,0 +1,43 @@ +name: ci + +on: + push: + pull_request: + +env: + BAZEL_CXXOPTS: -std=c++14 + +jobs: + Linux: + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Tests + run: bazel test --cxxopt=-std=c++14 --features=external_include_paths --test_output=errors ... + + macOS: + runs-on: macos-latest + steps: + + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Tests + run: bazel test --cxxopt=-std=c++14 --features=external_include_paths --test_output=errors ... + + + Windows: + runs-on: windows-latest + steps: + + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Tests + run: bazel test --cxxopt=/std:c++14 --features=external_include_paths --test_output=errors ... 
diff --git a/third_party/googletest/.gitignore b/third_party/googletest/.gitignore new file mode 100644 index 0000000..fede02f --- /dev/null +++ b/third_party/googletest/.gitignore @@ -0,0 +1,88 @@ +# Ignore CI build directory +build/ +xcuserdata +cmake-build-debug/ +.idea/ +bazel-bin +bazel-genfiles +bazel-googletest +bazel-out +bazel-testlogs +# python +*.pyc + +# Visual Studio files +.vs +*.sdf +*.opensdf +*.VC.opendb +*.suo +*.user +_ReSharper.Caches/ +Win32-Debug/ +Win32-Release/ +x64-Debug/ +x64-Release/ + +# VSCode files +.cache/ +cmake-variants.yaml + +# Ignore autoconf / automake files +Makefile.in +aclocal.m4 +configure +build-aux/ +autom4te.cache/ +googletest/m4/libtool.m4 +googletest/m4/ltoptions.m4 +googletest/m4/ltsugar.m4 +googletest/m4/ltversion.m4 +googletest/m4/lt~obsolete.m4 +googlemock/m4 + +# Ignore generated directories. +googlemock/fused-src/ +googletest/fused-src/ + +# macOS files +.DS_Store +googletest/.DS_Store +googletest/xcode/.DS_Store + +# Ignore cmake generated directories and files. 
+CMakeFiles +CTestTestfile.cmake +Makefile +cmake_install.cmake +googlemock/CMakeFiles +googlemock/CTestTestfile.cmake +googlemock/Makefile +googlemock/cmake_install.cmake +googlemock/gtest +/bin +/googlemock/gmock.dir +/googlemock/gmock_main.dir +/googlemock/RUN_TESTS.vcxproj.filters +/googlemock/RUN_TESTS.vcxproj +/googlemock/INSTALL.vcxproj.filters +/googlemock/INSTALL.vcxproj +/googlemock/gmock_main.vcxproj.filters +/googlemock/gmock_main.vcxproj +/googlemock/gmock.vcxproj.filters +/googlemock/gmock.vcxproj +/googlemock/gmock.sln +/googlemock/ALL_BUILD.vcxproj.filters +/googlemock/ALL_BUILD.vcxproj +/lib +/Win32 +/ZERO_CHECK.vcxproj.filters +/ZERO_CHECK.vcxproj +/RUN_TESTS.vcxproj.filters +/RUN_TESTS.vcxproj +/INSTALL.vcxproj.filters +/INSTALL.vcxproj +/googletest-distribution.sln +/CMakeCache.txt +/ALL_BUILD.vcxproj.filters +/ALL_BUILD.vcxproj diff --git a/third_party/googletest/BUILD.bazel b/third_party/googletest/BUILD.bazel new file mode 100644 index 0000000..b1e3b7f --- /dev/null +++ b/third_party/googletest/BUILD.bazel @@ -0,0 +1,219 @@ +# Copyright 2017 Google Inc. +# All Rights Reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Bazel Build for Google C++ Testing Framework(Google Test) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +config_setting( + name = "qnx", + constraint_values = ["@platforms//os:qnx"], +) + +config_setting( + name = "windows", + constraint_values = ["@platforms//os:windows"], +) + +config_setting( + name = "freebsd", + constraint_values = ["@platforms//os:freebsd"], +) + +config_setting( + name = "openbsd", + constraint_values = ["@platforms//os:openbsd"], +) + +config_setting( + name = "msvc_compiler", + flag_values = { + "@bazel_tools//tools/cpp:compiler": "msvc-cl", + }, + visibility = [":__subpackages__"], +) + +config_setting( + name = "has_absl", + values = {"define": "absl=1"}, +) + +# Library that defines the FRIEND_TEST macro. 
+cc_library( + name = "gtest_prod", + hdrs = ["googletest/include/gtest/gtest_prod.h"], + includes = ["googletest/include"], +) + +# Google Test including Google Mock +cc_library( + name = "gtest", + srcs = glob( + include = [ + "googletest/src/*.cc", + "googletest/src/*.h", + "googletest/include/gtest/**/*.h", + "googlemock/src/*.cc", + "googlemock/include/gmock/**/*.h", + ], + exclude = [ + "googletest/src/gtest-all.cc", + "googletest/src/gtest_main.cc", + "googlemock/src/gmock-all.cc", + "googlemock/src/gmock_main.cc", + ], + ), + hdrs = glob([ + "googletest/include/gtest/*.h", + "googlemock/include/gmock/*.h", + ]), + copts = select({ + ":qnx": [], + ":windows": [], + "//conditions:default": ["-pthread"], + }), + defines = select({ + ":has_absl": ["GTEST_HAS_ABSL=1"], + "//conditions:default": [], + }), + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + includes = [ + "googlemock", + "googlemock/include", + "googletest", + "googletest/include", + ], + linkopts = select({ + ":qnx": ["-lregex"], + ":windows": [], + ":freebsd": [ + "-lm", + "-pthread", + ], + ":openbsd": [ + "-lm", + "-pthread", + ], + "//conditions:default": ["-pthread"], + }), + deps = select({ + ":has_absl": [ + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/debugging:failure_signal_handler", + "@com_google_absl//absl/debugging:stacktrace", + "@com_google_absl//absl/debugging:symbolize", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/flags:reflection", + "@com_google_absl//absl/flags:usage", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "@com_googlesource_code_re2//:re2", + ], + "//conditions:default": [], + }), +) + +cc_library( + name = "gtest_main", + srcs = ["googlemock/src/gmock_main.cc"], + features = select({ + ":windows": 
["windows_export_all_symbols"], + "//conditions:default": [], + }), + deps = [":gtest"], +) + +# The following rules build samples of how to use gTest. +cc_library( + name = "gtest_sample_lib", + srcs = [ + "googletest/samples/sample1.cc", + "googletest/samples/sample2.cc", + "googletest/samples/sample4.cc", + ], + hdrs = [ + "googletest/samples/prime_tables.h", + "googletest/samples/sample1.h", + "googletest/samples/sample2.h", + "googletest/samples/sample3-inl.h", + "googletest/samples/sample4.h", + ], + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), +) + +cc_test( + name = "gtest_samples", + size = "small", + # All Samples except: + # sample9 (main) + # sample10 (main and takes a command line option and needs to be separate) + srcs = [ + "googletest/samples/sample1_unittest.cc", + "googletest/samples/sample2_unittest.cc", + "googletest/samples/sample3_unittest.cc", + "googletest/samples/sample4_unittest.cc", + "googletest/samples/sample5_unittest.cc", + "googletest/samples/sample6_unittest.cc", + "googletest/samples/sample7_unittest.cc", + "googletest/samples/sample8_unittest.cc", + ], + linkstatic = 0, + deps = [ + "gtest_sample_lib", + ":gtest_main", + ], +) + +cc_test( + name = "sample9_unittest", + size = "small", + srcs = ["googletest/samples/sample9_unittest.cc"], + deps = [":gtest"], +) + +cc_test( + name = "sample10_unittest", + size = "small", + srcs = ["googletest/samples/sample10_unittest.cc"], + deps = [":gtest"], +) diff --git a/third_party/googletest/CMakeLists.txt b/third_party/googletest/CMakeLists.txt new file mode 100644 index 0000000..089ac98 --- /dev/null +++ b/third_party/googletest/CMakeLists.txt @@ -0,0 +1,27 @@ +# Note: CMake support is community-based. The maintainers do not use CMake +# internally. 
+ +cmake_minimum_required(VERSION 3.13) + +project(googletest-distribution) +set(GOOGLETEST_VERSION 1.14.0) + +if(NOT CYGWIN AND NOT MSYS AND NOT ${CMAKE_SYSTEM_NAME} STREQUAL QNX) + set(CMAKE_CXX_EXTENSIONS OFF) +endif() + +enable_testing() + +include(CMakeDependentOption) +include(GNUInstallDirs) + +#Note that googlemock target already builds googletest +option(BUILD_GMOCK "Builds the googlemock subproject" ON) +option(INSTALL_GTEST "Enable installation of googletest. (Projects embedding googletest may want to turn this OFF.)" ON) +option(GTEST_HAS_ABSL "Use Abseil and RE2. Requires Abseil and RE2 to be separately added to the build." OFF) + +if(BUILD_GMOCK) + add_subdirectory( googlemock ) +else() + add_subdirectory( googletest ) +endif() diff --git a/third_party/googletest/CONTRIBUTING.md b/third_party/googletest/CONTRIBUTING.md new file mode 100644 index 0000000..8bed14b --- /dev/null +++ b/third_party/googletest/CONTRIBUTING.md @@ -0,0 +1,141 @@ +# How to become a contributor and submit your own code + +## Contributor License Agreements + +We'd love to accept your patches! Before we can take them, we have to jump a +couple of legal hurdles. + +Please fill out either the individual or corporate Contributor License Agreement +(CLA). + +* If you are an individual writing original source code and you're sure you + own the intellectual property, then you'll need to sign an + [individual CLA](https://developers.google.com/open-source/cla/individual). +* If you work for a company that wants to allow you to contribute your work, + then you'll need to sign a + [corporate CLA](https://developers.google.com/open-source/cla/corporate). + +Follow either of the two links above to access the appropriate CLA and +instructions for how to sign and return it. Once we receive it, we'll be able to +accept your pull requests. + +## Are you a Googler? + +If you are a Googler, please make an attempt to submit an internal contribution +rather than a GitHub Pull Request. 
If you are not able to submit internally, a +PR is acceptable as an alternative. + +## Contributing A Patch + +1. Submit an issue describing your proposed change to the + [issue tracker](https://github.com/google/googletest/issues). +2. Please don't mix more than one logical change per submittal, because it + makes the history hard to follow. If you want to make a change that doesn't + have a corresponding issue in the issue tracker, please create one. +3. Also, coordinate with team members that are listed on the issue in question. + This ensures that work isn't being duplicated and communicating your plan + early also generally leads to better patches. +4. If your proposed change is accepted, and you haven't already done so, sign a + Contributor License Agreement + ([see details above](#contributor-license-agreements)). +5. Fork the desired repo, develop and test your code changes. +6. Ensure that your code adheres to the existing style in the sample to which + you are contributing. +7. Ensure that your code has an appropriate set of unit tests which all pass. +8. Submit a pull request. + +## The Google Test and Google Mock Communities + +The Google Test community exists primarily through the +[discussion group](http://groups.google.com/group/googletestframework) and the +GitHub repository. Likewise, the Google Mock community exists primarily through +their own [discussion group](http://groups.google.com/group/googlemock). You are +definitely encouraged to contribute to the discussion and you can also help us +to keep the effectiveness of the group high by following and promoting the +guidelines listed here. + +### Please Be Friendly + +Showing courtesy and respect to others is a vital part of the Google culture, +and we strongly encourage everyone participating in Google Test development to +join us in accepting nothing less. 
Of course, being courteous is not the same as +failing to constructively disagree with each other, but it does mean that we +should be respectful of each other when enumerating the 42 technical reasons +that a particular proposal may not be the best choice. There's never a reason to +be antagonistic or dismissive toward anyone who is sincerely trying to +contribute to a discussion. + +Sure, C++ testing is serious business and all that, but it's also a lot of fun. +Let's keep it that way. Let's strive to be one of the friendliest communities in +all of open source. + +As always, discuss Google Test in the official GoogleTest discussion group. You +don't have to actually submit code in order to sign up. Your participation +itself is a valuable contribution. + +## Style + +To keep the source consistent, readable, diffable and easy to merge, we use a +fairly rigid coding style, as defined by the +[google-styleguide](https://github.com/google/styleguide) project. All patches +will be expected to conform to the style outlined +[here](https://google.github.io/styleguide/cppguide.html). Use +[.clang-format](https://github.com/google/googletest/blob/main/.clang-format) to +check your formatting. + +## Requirements for Contributors + +If you plan to contribute a patch, you need to build Google Test, Google Mock, +and their own tests from a git checkout, which has further requirements: + +* [Python](https://www.python.org/) v3.6 or newer (for running some of the + tests and re-generating certain source files from templates) +* [CMake](https://cmake.org/) v2.8.12 or newer + +## Developing Google Test and Google Mock + +This section discusses how to make your own changes to the Google Test project. + +### Testing Google Test and Google Mock Themselves + +To make sure your changes work as intended and don't break existing +functionality, you'll want to compile and run Google Test and GoogleMock's own +tests. 
For that you can use CMake: + +``` +mkdir mybuild +cd mybuild +cmake -Dgtest_build_tests=ON -Dgmock_build_tests=ON ${GTEST_REPO_DIR} +``` + +To choose between building only Google Test or Google Mock, you may modify your +cmake command to be one of each + +``` +cmake -Dgtest_build_tests=ON ${GTEST_DIR} # sets up Google Test tests +cmake -Dgmock_build_tests=ON ${GMOCK_DIR} # sets up Google Mock tests +``` + +Make sure you have Python installed, as some of Google Test's tests are written +in Python. If the cmake command complains about not being able to find Python +(`Could NOT find PythonInterp (missing: PYTHON_EXECUTABLE)`), try telling it +explicitly where your Python executable can be found: + +``` +cmake -DPYTHON_EXECUTABLE=path/to/python ... +``` + +Next, you can build Google Test and / or Google Mock and all desired tests. On +\*nix, this is usually done by + +``` +make +``` + +To run the tests, do + +``` +make test +``` + +All tests should pass. diff --git a/third_party/googletest/CONTRIBUTORS b/third_party/googletest/CONTRIBUTORS new file mode 100644 index 0000000..77397a5 --- /dev/null +++ b/third_party/googletest/CONTRIBUTORS @@ -0,0 +1,65 @@ +# This file contains a list of people who've made non-trivial +# contribution to the Google C++ Testing Framework project. People +# who commit code to the project are encouraged to add their names +# here. Please keep the list sorted by first names. 
+ +Ajay Joshi +Balázs Dán +Benoit Sigoure +Bharat Mediratta +Bogdan Piloca +Chandler Carruth +Chris Prince +Chris Taylor +Dan Egnor +Dave MacLachlan +David Anderson +Dean Sturtevant +Eric Roman +Gene Volovich +Hady Zalek +Hal Burch +Jeffrey Yasskin +Jim Keller +Joe Walnes +Jon Wray +Jói Sigurðsson +Keir Mierle +Keith Ray +Kenton Varda +Kostya Serebryany +Krystian Kuzniarek +Lev Makhlis +Manuel Klimek +Mario Tanev +Mark Paskin +Markus Heule +Martijn Vels +Matthew Simmons +Mika Raento +Mike Bland +Miklós Fazekas +Neal Norwitz +Nermin Ozkiranartli +Owen Carlsen +Paneendra Ba +Pasi Valminen +Patrick Hanna +Patrick Riley +Paul Menage +Peter Kaminski +Piotr Kaminski +Preston Jackson +Rainer Klaffenboeck +Russ Cox +Russ Rufer +Sean Mcafee +Sigurður Ásgeirsson +Sverre Sundsdal +Szymon Sobik +Takeshi Yoshino +Tracy Bialik +Vadim Berman +Vlad Losev +Wolfgang Klier +Zhanyong Wan diff --git a/third_party/googletest/LICENSE b/third_party/googletest/LICENSE new file mode 100644 index 0000000..1941a11 --- /dev/null +++ b/third_party/googletest/LICENSE @@ -0,0 +1,28 @@ +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/googletest/README.md b/third_party/googletest/README.md new file mode 100644 index 0000000..443e020 --- /dev/null +++ b/third_party/googletest/README.md @@ -0,0 +1,146 @@ +# GoogleTest + +### Announcements + +#### Live at Head + +GoogleTest now follows the +[Abseil Live at Head philosophy](https://abseil.io/about/philosophy#upgrade-support). +We recommend +[updating to the latest commit in the `main` branch as often as possible](https://github.com/abseil/abseil-cpp/blob/master/FAQ.md#what-is-live-at-head-and-how-do-i-do-it). +We do publish occasional semantic versions, tagged with +`v${major}.${minor}.${patch}` (e.g. `v1.13.0`). + +#### Documentation Updates + +Our documentation is now live on GitHub Pages at +https://google.github.io/googletest/. We recommend browsing the documentation on +GitHub Pages rather than directly in the repository. + +#### Release 1.13.0 + +[Release 1.13.0](https://github.com/google/googletest/releases/tag/v1.13.0) is +now available. + +The 1.13.x branch requires at least C++14. + +#### Continuous Integration + +We use Google's internal systems for continuous integration. \ +GitHub Actions were added for the convenience of open-source contributors. 
They +are exclusively maintained by the open-source community and not used by the +GoogleTest team. + +#### Coming Soon + +* We are planning to take a dependency on + [Abseil](https://github.com/abseil/abseil-cpp). +* More documentation improvements are planned. + +## Welcome to **GoogleTest**, Google's C++ test framework! + +This repository is a merger of the formerly separate GoogleTest and GoogleMock +projects. These were so closely related that it makes sense to maintain and +release them together. + +### Getting Started + +See the [GoogleTest User's Guide](https://google.github.io/googletest/) for +documentation. We recommend starting with the +[GoogleTest Primer](https://google.github.io/googletest/primer.html). + +More information about building GoogleTest can be found at +[googletest/README.md](googletest/README.md). + +## Features + +* xUnit test framework: \ + Googletest is based on the [xUnit](https://en.wikipedia.org/wiki/XUnit) + testing framework, a popular architecture for unit testing +* Test discovery: \ + Googletest automatically discovers and runs your tests, eliminating the need + to manually register your tests +* Rich set of assertions: \ + Googletest provides a variety of assertions, such as equality, inequality, + exceptions, and more, making it easy to test your code +* User-defined assertions: \ + You can define your own assertions with Googletest, making it simple to + write tests that are specific to your code +* Death tests: \ + Googletest supports death tests, which verify that your code exits in a + certain way, making it useful for testing error-handling code +* Fatal and non-fatal failures: \ + You can specify whether a test failure should be treated as fatal or + non-fatal with Googletest, allowing tests to continue running even if a + failure occurs +* Value-parameterized tests: \ + Googletest supports value-parameterized tests, which run multiple times with + different input values, making it useful for testing functions that 
take + different inputs +* Type-parameterized tests: \ + Googletest also supports type-parameterized tests, which run with different + data types, making it useful for testing functions that work with different + data types +* Various options for running tests: \ + Googletest provides many options for running tests including running + individual tests, running tests in a specific order and running tests in + parallel + +## Supported Platforms + +GoogleTest follows Google's +[Foundational C++ Support Policy](https://opensource.google/documentation/policies/cplusplus-support). +See +[this table](https://github.com/google/oss-policies-info/blob/main/foundational-cxx-support-matrix.md) +for a list of currently supported versions of compilers, platforms, and build +tools. + +## Who Is Using GoogleTest? + +In addition to many internal projects at Google, GoogleTest is also used by the +following notable projects: + +* The [Chromium projects](http://www.chromium.org/) (behind the Chrome browser + and Chrome OS). +* The [LLVM](http://llvm.org/) compiler. +* [Protocol Buffers](https://github.com/google/protobuf), Google's data + interchange format. +* The [OpenCV](http://opencv.org/) computer vision library. + +## Related Open Source Projects + +[GTest Runner](https://github.com/nholthaus/gtest-runner) is a Qt5 based +automated test-runner and Graphical User Interface with powerful features for +Windows and Linux platforms. + +[GoogleTest UI](https://github.com/ospector/gtest-gbar) is a test runner that +runs your test binary, allows you to track its progress via a progress bar, and +displays a list of test failures. Clicking on one shows failure text. GoogleTest +UI is written in C#. + +[GTest TAP Listener](https://github.com/kinow/gtest-tap-listener) is an event +listener for GoogleTest that implements the +[TAP protocol](https://en.wikipedia.org/wiki/Test_Anything_Protocol) for test +result output. If your test runner understands TAP, you may find it useful. 
+ +[gtest-parallel](https://github.com/google/gtest-parallel) is a test runner that +runs tests from your binary in parallel to provide significant speed-up. + +[GoogleTest Adapter](https://marketplace.visualstudio.com/items?itemName=DavidSchuldenfrei.gtest-adapter) +is a VS Code extension allowing to view GoogleTest in a tree view and run/debug +your tests. + +[C++ TestMate](https://github.com/matepek/vscode-catch2-test-adapter) is a VS +Code extension allowing to view GoogleTest in a tree view and run/debug your +tests. + +[Cornichon](https://pypi.org/project/cornichon/) is a small Gherkin DSL parser +that generates stub code for GoogleTest. + +## Contributing Changes + +Please read +[`CONTRIBUTING.md`](https://github.com/google/googletest/blob/main/CONTRIBUTING.md) +for details on how to contribute to this project. + +Happy testing! diff --git a/third_party/googletest/WORKSPACE b/third_party/googletest/WORKSPACE new file mode 100644 index 0000000..f819ffe --- /dev/null +++ b/third_party/googletest/WORKSPACE @@ -0,0 +1,27 @@ +workspace(name = "com_google_googletest") + +load("//:googletest_deps.bzl", "googletest_deps") +googletest_deps() + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "rules_python", # 2023-07-31T20:39:27Z + sha256 = "1250b59a33c591a1c4ba68c62e95fc88a84c334ec35a2e23f46cbc1b9a5a8b55", + strip_prefix = "rules_python-e355becc30275939d87116a4ec83dad4bb50d9e1", + urls = ["https://github.com/bazelbuild/rules_python/archive/e355becc30275939d87116a4ec83dad4bb50d9e1.zip"], +) + +http_archive( + name = "bazel_skylib", # 2023-05-31T19:24:07Z + sha256 = "08c0386f45821ce246bbbf77503c973246ed6ee5c3463e41efc197fa9bc3a7f4", + strip_prefix = "bazel-skylib-288731ef9f7f688932bd50e704a91a45ec185f9b", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/288731ef9f7f688932bd50e704a91a45ec185f9b.zip"], +) + +http_archive( + name = "platforms", # 2023-07-28T19:44:27Z + sha256 = 
"40eb313613ff00a5c03eed20aba58890046f4d38dec7344f00bb9a8867853526", + strip_prefix = "platforms-4ad40ef271da8176d4fc0194d2089b8a76e19d7b", + urls = ["https://github.com/bazelbuild/platforms/archive/4ad40ef271da8176d4fc0194d2089b8a76e19d7b.zip"], +) diff --git a/third_party/googletest/ci/linux-presubmit.sh b/third_party/googletest/ci/linux-presubmit.sh new file mode 100644 index 0000000..6bac887 --- /dev/null +++ b/third_party/googletest/ci/linux-presubmit.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# +# Copyright 2020, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set -euox pipefail + +readonly LINUX_LATEST_CONTAINER="gcr.io/google.com/absl-177019/linux_hybrid-latest:20230217" +readonly LINUX_GCC_FLOOR_CONTAINER="gcr.io/google.com/absl-177019/linux_gcc-floor:20230120" + +if [[ -z ${GTEST_ROOT:-} ]]; then + GTEST_ROOT="$(realpath $(dirname ${0})/..)" +fi + +if [[ -z ${STD:-} ]]; then + STD="c++14 c++17 c++20" +fi + +# Test the CMake build +for cc in /usr/local/bin/gcc /opt/llvm/clang/bin/clang; do + for cmake_off_on in OFF ON; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --tmpfs="/build:exec" \ + --workdir="/build" \ + --rm \ + --env="CC=${cc}" \ + --env=CXXFLAGS="-Werror -Wdeprecated" \ + ${LINUX_LATEST_CONTAINER} \ + /bin/bash -c " + cmake /src \ + -DCMAKE_CXX_STANDARD=14 \ + -Dgtest_build_samples=ON \ + -Dgtest_build_tests=ON \ + -Dgmock_build_tests=ON \ + -Dcxx_no_exception=${cmake_off_on} \ + -Dcxx_no_rtti=${cmake_off_on} && \ + make -j$(nproc) && \ + ctest -j$(nproc) --output-on-failure" + done +done + +# Do one test with an older version of GCC +time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/usr/local/bin/gcc" \ + --env="BAZEL_CXXOPTS=-std=c++14" \ + ${LINUX_GCC_FLOOR_CONTAINER} \ + /usr/local/bin/bazel test ... 
\ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wundef" \ + --copt="-Wno-error=pragmas" \ + --distdir="/bazel-distdir" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors + +# Test GCC +for std in ${STD}; do + for absl in 0 1; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/usr/local/bin/gcc" \ + --env="BAZEL_CXXOPTS=-std=${std}" \ + ${LINUX_LATEST_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wundef" \ + --define="absl=${absl}" \ + --distdir="/bazel-distdir" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors + done +done + +# Test Clang +for std in ${STD}; do + for absl in 0 1; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/opt/llvm/clang/bin/clang" \ + --env="BAZEL_CXXOPTS=-std=${std}" \ + ${LINUX_LATEST_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="--gcc-toolchain=/usr/local" \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wundef" \ + --define="absl=${absl}" \ + --distdir="/bazel-distdir" \ + --features=external_include_paths \ + --keep_going \ + --linkopt="--gcc-toolchain=/usr/local" \ + --show_timestamps \ + --test_output=errors + done +done diff --git a/third_party/googletest/ci/macos-presubmit.sh b/third_party/googletest/ci/macos-presubmit.sh new file mode 100644 index 0000000..681ebc2 --- /dev/null +++ b/third_party/googletest/ci/macos-presubmit.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# +# Copyright 2020, Google Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +set -euox pipefail + +if [[ -z ${GTEST_ROOT:-} ]]; then + GTEST_ROOT="$(realpath $(dirname ${0})/..)" +fi + +# Test the CMake build +for cmake_off_on in OFF ON; do + BUILD_DIR=$(mktemp -d build_dir.XXXXXXXX) + cd ${BUILD_DIR} + time cmake ${GTEST_ROOT} \ + -DCMAKE_CXX_STANDARD=14 \ + -Dgtest_build_samples=ON \ + -Dgtest_build_tests=ON \ + -Dgmock_build_tests=ON \ + -Dcxx_no_exception=${cmake_off_on} \ + -Dcxx_no_rtti=${cmake_off_on} + time make + time ctest -j$(nproc) --output-on-failure +done + +# Test the Bazel build + +# If we are running on Kokoro, check for a versioned Bazel binary. +KOKORO_GFILE_BAZEL_BIN="bazel-5.1.1-darwin-x86_64" +if [[ ${KOKORO_GFILE_DIR:-} ]] && [[ -f ${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN} ]]; then + BAZEL_BIN="${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN}" + chmod +x ${BAZEL_BIN} +else + BAZEL_BIN="bazel" +fi + +cd ${GTEST_ROOT} +for absl in 0 1; do + ${BAZEL_BIN} test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wundef" \ + --cxxopt="-std=c++14" \ + --define="absl=${absl}" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors +done diff --git a/third_party/googletest/ci/windows-presubmit.bat b/third_party/googletest/ci/windows-presubmit.bat new file mode 100644 index 0000000..48962eb --- /dev/null +++ b/third_party/googletest/ci/windows-presubmit.bat @@ -0,0 +1,58 @@ +SETLOCAL ENABLEDELAYEDEXPANSION + +SET BAZEL_EXE=%KOKORO_GFILE_DIR%\bazel-5.1.1-windows-x86_64.exe + +SET PATH=C:\Python34;%PATH% +SET BAZEL_PYTHON=C:\python34\python.exe +SET BAZEL_SH=C:\tools\msys64\usr\bin\bash.exe +SET CMAKE_BIN="cmake.exe" +SET CTEST_BIN="ctest.exe" +SET CTEST_OUTPUT_ON_FAILURE=1 +SET CMAKE_BUILD_PARALLEL_LEVEL=16 +SET CTEST_PARALLEL_LEVEL=16 + +IF EXIST git\googletest ( + CD git\googletest +) ELSE IF EXIST github\googletest ( + CD github\googletest +) + +IF %errorlevel% neq 0 EXIT /B 1 + +:: ---------------------------------------------------------------------------- +:: CMake 
+MKDIR cmake_msvc2022 +CD cmake_msvc2022 + +%CMAKE_BIN% .. ^ + -G "Visual Studio 17 2022" ^ + -DPYTHON_EXECUTABLE:FILEPATH=c:\python37\python.exe ^ + -DPYTHON_INCLUDE_DIR:PATH=c:\python37\include ^ + -DPYTHON_LIBRARY:FILEPATH=c:\python37\lib\site-packages\pip ^ + -Dgtest_build_samples=ON ^ + -Dgtest_build_tests=ON ^ + -Dgmock_build_tests=ON +IF %errorlevel% neq 0 EXIT /B 1 + +%CMAKE_BIN% --build . --target ALL_BUILD --config Debug -- -maxcpucount +IF %errorlevel% neq 0 EXIT /B 1 + +%CTEST_BIN% -C Debug --timeout 600 +IF %errorlevel% neq 0 EXIT /B 1 + +CD .. +RMDIR /S /Q cmake_msvc2022 + +:: ---------------------------------------------------------------------------- +:: Bazel + +SET BAZEL_VS=C:\Program Files\Microsoft Visual Studio\2022\Community +%BAZEL_EXE% test ... ^ + --compilation_mode=dbg ^ + --copt=/std:c++14 ^ + --copt=/WX ^ + --features=external_include_paths ^ + --keep_going ^ + --test_output=errors ^ + --test_tag_filters=-no_test_msvc2017 +IF %errorlevel% neq 0 EXIT /B 1 diff --git a/third_party/googletest/docs/_config.yml b/third_party/googletest/docs/_config.yml new file mode 100644 index 0000000..d12867e --- /dev/null +++ b/third_party/googletest/docs/_config.yml @@ -0,0 +1 @@ +title: GoogleTest diff --git a/third_party/googletest/docs/_data/navigation.yml b/third_party/googletest/docs/_data/navigation.yml new file mode 100644 index 0000000..9f33327 --- /dev/null +++ b/third_party/googletest/docs/_data/navigation.yml @@ -0,0 +1,43 @@ +nav: +- section: "Get Started" + items: + - title: "Supported Platforms" + url: "/platforms.html" + - title: "Quickstart: Bazel" + url: "/quickstart-bazel.html" + - title: "Quickstart: CMake" + url: "/quickstart-cmake.html" +- section: "Guides" + items: + - title: "GoogleTest Primer" + url: "/primer.html" + - title: "Advanced Topics" + url: "/advanced.html" + - title: "Mocking for Dummies" + url: "/gmock_for_dummies.html" + - title: "Mocking Cookbook" + url: "/gmock_cook_book.html" + - title: "Mocking Cheat Sheet" + url: 
"/gmock_cheat_sheet.html" +- section: "References" + items: + - title: "Testing Reference" + url: "/reference/testing.html" + - title: "Mocking Reference" + url: "/reference/mocking.html" + - title: "Assertions" + url: "/reference/assertions.html" + - title: "Matchers" + url: "/reference/matchers.html" + - title: "Actions" + url: "/reference/actions.html" + - title: "Testing FAQ" + url: "/faq.html" + - title: "Mocking FAQ" + url: "/gmock_faq.html" + - title: "Code Samples" + url: "/samples.html" + - title: "Using pkg-config" + url: "/pkgconfig.html" + - title: "Community Documentation" + url: "/community_created_documentation.html" diff --git a/third_party/googletest/docs/_layouts/default.html b/third_party/googletest/docs/_layouts/default.html new file mode 100644 index 0000000..c7f331b --- /dev/null +++ b/third_party/googletest/docs/_layouts/default.html @@ -0,0 +1,58 @@ + + + + + + + +{% seo %} + + + + + + +
+
+ {{ content }} +
+ +
+ + + + diff --git a/third_party/googletest/docs/_sass/main.scss b/third_party/googletest/docs/_sass/main.scss new file mode 100644 index 0000000..92edc87 --- /dev/null +++ b/third_party/googletest/docs/_sass/main.scss @@ -0,0 +1,200 @@ +// Styles for GoogleTest docs website on GitHub Pages. +// Color variables are defined in +// https://github.com/pages-themes/primer/tree/master/_sass/primer-support/lib/variables + +$sidebar-width: 260px; + +body { + display: flex; + margin: 0; +} + +.sidebar { + background: $black; + color: $text-white; + flex-shrink: 0; + height: 100vh; + overflow: auto; + position: sticky; + top: 0; + width: $sidebar-width; +} + +.sidebar h1 { + font-size: 1.5em; +} + +.sidebar h2 { + color: $gray-light; + font-size: 0.8em; + font-weight: normal; + margin-bottom: 0.8em; + padding-left: 2.5em; + text-transform: uppercase; +} + +.sidebar .header { + background: $black; + padding: 2em; + position: sticky; + top: 0; + width: 100%; +} + +.sidebar .header a { + color: $text-white; + text-decoration: none; +} + +.sidebar .nav-toggle { + display: none; +} + +.sidebar .expander { + cursor: pointer; + display: none; + height: 3em; + position: absolute; + right: 1em; + top: 1.5em; + width: 3em; +} + +.sidebar .expander .arrow { + border: solid $white; + border-width: 0 3px 3px 0; + display: block; + height: 0.7em; + margin: 1em auto; + transform: rotate(45deg); + transition: transform 0.5s; + width: 0.7em; +} + +.sidebar nav { + width: 100%; +} + +.sidebar nav ul { + list-style-type: none; + margin-bottom: 1em; + padding: 0; + + &:last-child { + margin-bottom: 2em; + } + + a { + text-decoration: none; + } + + li { + color: $text-white; + padding-left: 2em; + text-decoration: none; + } + + li.active { + background: $border-gray-darker; + font-weight: bold; + } + + li:hover { + background: $border-gray-darker; + } +} + +.main { + background-color: $bg-gray; + width: calc(100% - #{$sidebar-width}); +} + +.main .main-inner { + background-color: $white; + 
padding: 2em; +} + +.main .footer { + margin: 0; + padding: 2em; +} + +.main table th { + text-align: left; +} + +.main .callout { + border-left: 0.25em solid $white; + padding: 1em; + + a { + text-decoration: underline; + } + + &.important { + background-color: $bg-yellow-light; + border-color: $bg-yellow; + color: $black; + } + + &.note { + background-color: $bg-blue-light; + border-color: $text-blue; + color: $text-blue; + } + + &.tip { + background-color: $green-000; + border-color: $green-700; + color: $green-700; + } + + &.warning { + background-color: $red-000; + border-color: $text-red; + color: $text-red; + } +} + +.main .good pre { + background-color: $bg-green-light; +} + +.main .bad pre { + background-color: $red-000; +} + +@media all and (max-width: 768px) { + body { + flex-direction: column; + } + + .sidebar { + height: auto; + position: relative; + width: 100%; + } + + .sidebar .expander { + display: block; + } + + .sidebar nav { + height: 0; + overflow: hidden; + } + + .sidebar .nav-toggle:checked { + & ~ nav { + height: auto; + } + + & + .expander .arrow { + transform: rotate(-135deg); + } + } + + .main { + width: 100%; + } +} diff --git a/third_party/googletest/docs/advanced.md b/third_party/googletest/docs/advanced.md new file mode 100644 index 0000000..3871db1 --- /dev/null +++ b/third_party/googletest/docs/advanced.md @@ -0,0 +1,2436 @@ +# Advanced GoogleTest Topics + +## Introduction + +Now that you have read the [GoogleTest Primer](primer.md) and learned how to +write tests using GoogleTest, it's time to learn some new tricks. This document +will show you more assertions as well as how to construct complex failure +messages, propagate fatal failures, reuse and speed up your test fixtures, and +use various flags with your tests. + +## More Assertions + +This section covers some less frequently used, but still significant, +assertions. 
+ +### Explicit Success and Failure + +See [Explicit Success and Failure](reference/assertions.md#success-failure) in +the Assertions Reference. + +### Exception Assertions + +See [Exception Assertions](reference/assertions.md#exceptions) in the Assertions +Reference. + +### Predicate Assertions for Better Error Messages + +Even though GoogleTest has a rich set of assertions, they can never be complete, +as it's impossible (nor a good idea) to anticipate all scenarios a user might +run into. Therefore, sometimes a user has to use `EXPECT_TRUE()` to check a +complex expression, for lack of a better macro. This has the problem of not +showing you the values of the parts of the expression, making it hard to +understand what went wrong. As a workaround, some users choose to construct the +failure message by themselves, streaming it into `EXPECT_TRUE()`. However, this +is awkward especially when the expression has side-effects or is expensive to +evaluate. + +GoogleTest gives you three different options to solve this problem: + +#### Using an Existing Boolean Function + +If you already have a function or functor that returns `bool` (or a type that +can be implicitly converted to `bool`), you can use it in a *predicate +assertion* to get the function arguments printed for free. See +[`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) in the Assertions +Reference for details. + +#### Using a Function That Returns an AssertionResult + +While `EXPECT_PRED*()` and friends are handy for a quick job, the syntax is not +satisfactory: you have to use different macros for different arities, and it +feels more like Lisp than C++. The `::testing::AssertionResult` class solves +this problem. + +An `AssertionResult` object represents the result of an assertion (whether it's +a success or a failure, and an associated message). 
You can create an +`AssertionResult` using one of these factory functions: + +```c++ +namespace testing { + +// Returns an AssertionResult object to indicate that an assertion has +// succeeded. +AssertionResult AssertionSuccess(); + +// Returns an AssertionResult object to indicate that an assertion has +// failed. +AssertionResult AssertionFailure(); + +} +``` + +You can then use the `<<` operator to stream messages to the `AssertionResult` +object. + +To provide more readable messages in Boolean assertions (e.g. `EXPECT_TRUE()`), +write a predicate function that returns `AssertionResult` instead of `bool`. For +example, if you define `IsEven()` as: + +```c++ +testing::AssertionResult IsEven(int n) { + if ((n % 2) == 0) + return testing::AssertionSuccess(); + else + return testing::AssertionFailure() << n << " is odd"; +} +``` + +instead of: + +```c++ +bool IsEven(int n) { + return (n % 2) == 0; +} +``` + +the failed assertion `EXPECT_TRUE(IsEven(Fib(4)))` will print: + +```none +Value of: IsEven(Fib(4)) + Actual: false (3 is odd) +Expected: true +``` + +instead of a more opaque + +```none +Value of: IsEven(Fib(4)) + Actual: false +Expected: true +``` + +If you want informative messages in `EXPECT_FALSE` and `ASSERT_FALSE` as well +(one third of Boolean assertions in the Google code base are negative ones), and +are fine with making the predicate slower in the success case, you can supply a +success message: + +```c++ +testing::AssertionResult IsEven(int n) { + if ((n % 2) == 0) + return testing::AssertionSuccess() << n << " is even"; + else + return testing::AssertionFailure() << n << " is odd"; +} +``` + +Then the statement `EXPECT_FALSE(IsEven(Fib(6)))` will print + +```none + Value of: IsEven(Fib(6)) + Actual: true (8 is even) + Expected: false +``` + +#### Using a Predicate-Formatter + +If you find the default message generated by +[`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) and +[`EXPECT_TRUE`](reference/assertions.md#EXPECT_TRUE) unsatisfactory, 
or some +arguments to your predicate do not support streaming to `ostream`, you can +instead use *predicate-formatter assertions* to *fully* customize how the +message is formatted. See +[`EXPECT_PRED_FORMAT*`](reference/assertions.md#EXPECT_PRED_FORMAT) in the +Assertions Reference for details. + +### Floating-Point Comparison + +See [Floating-Point Comparison](reference/assertions.md#floating-point) in the +Assertions Reference. + +#### Floating-Point Predicate-Format Functions + +Some floating-point operations are useful, but not that often used. In order to +avoid an explosion of new macros, we provide them as predicate-format functions +that can be used in the predicate assertion macro +[`EXPECT_PRED_FORMAT2`](reference/assertions.md#EXPECT_PRED_FORMAT), for +example: + +```c++ +using ::testing::FloatLE; +using ::testing::DoubleLE; +... +EXPECT_PRED_FORMAT2(FloatLE, val1, val2); +EXPECT_PRED_FORMAT2(DoubleLE, val1, val2); +``` + +The above code verifies that `val1` is less than, or approximately equal to, +`val2`. + +### Asserting Using gMock Matchers + +See [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) in the Assertions +Reference. + +### More String Assertions + +(Please read the [previous](#asserting-using-gmock-matchers) section first if +you haven't.) + +You can use the gMock [string matchers](reference/matchers.md#string-matchers) +with [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) to do more string +comparison tricks (sub-string, prefix, suffix, regular expression, and etc). For +example, + +```c++ +using ::testing::HasSubstr; +using ::testing::MatchesRegex; +... + ASSERT_THAT(foo_string, HasSubstr("needle")); + EXPECT_THAT(bar_string, MatchesRegex("\\w*\\d+")); +``` + +### Windows HRESULT assertions + +See [Windows HRESULT Assertions](reference/assertions.md#HRESULT) in the +Assertions Reference. 
+ +### Type Assertions + +You can call the function + +```c++ +::testing::StaticAssertTypeEq(); +``` + +to assert that types `T1` and `T2` are the same. The function does nothing if +the assertion is satisfied. If the types are different, the function call will +fail to compile, the compiler error message will say that `T1 and T2 are not the +same type` and most likely (depending on the compiler) show you the actual +values of `T1` and `T2`. This is mainly useful inside template code. + +**Caveat**: When used inside a member function of a class template or a function +template, `StaticAssertTypeEq()` is effective only if the function is +instantiated. For example, given: + +```c++ +template class Foo { + public: + void Bar() { testing::StaticAssertTypeEq(); } +}; +``` + +the code: + +```c++ +void Test1() { Foo foo; } +``` + +will not generate a compiler error, as `Foo::Bar()` is never actually +instantiated. Instead, you need: + +```c++ +void Test2() { Foo foo; foo.Bar(); } +``` + +to cause a compiler error. + +### Assertion Placement + +You can use assertions in any C++ function. In particular, it doesn't have to be +a method of the test fixture class. The one constraint is that assertions that +generate a fatal failure (`FAIL*` and `ASSERT_*`) can only be used in +void-returning functions. This is a consequence of Google's not using +exceptions. By placing it in a non-void function you'll get a confusing compile +error like `"error: void value not ignored as it ought to be"` or `"cannot +initialize return object of type 'bool' with an rvalue of type 'void'"` or +`"error: no viable conversion from 'void' to 'string'"`. + +If you need to use fatal assertions in a function that returns non-void, one +option is to make the function return the value in an out parameter instead. For +example, you can rewrite `T2 Foo(T1 x)` to `void Foo(T1 x, T2* result)`. You +need to make sure that `*result` contains some sensible value even when the +function returns prematurely. 
As the function now returns `void`, you can use +any assertion inside of it. + +If changing the function's type is not an option, you should just use assertions +that generate non-fatal failures, such as `ADD_FAILURE*` and `EXPECT_*`. + +{: .callout .note} +NOTE: Constructors and destructors are not considered void-returning functions, +according to the C++ language specification, and so you may not use fatal +assertions in them; you'll get a compilation error if you try. Instead, either +call `abort` and crash the entire test executable, or put the fatal assertion in +a `SetUp`/`TearDown` function; see +[constructor/destructor vs. `SetUp`/`TearDown`](faq.md#CtorVsSetUp) + +{: .callout .warning} +WARNING: A fatal assertion in a helper function (private void-returning method) +called from a constructor or destructor does not terminate the current test, as +your intuition might suggest: it merely returns from the constructor or +destructor early, possibly leaving your object in a partially-constructed or +partially-destructed state! You almost certainly want to `abort` or use +`SetUp`/`TearDown` instead. + +## Skipping test execution + +Related to the assertions `SUCCEED()` and `FAIL()`, you can prevent further test +execution at runtime with the `GTEST_SKIP()` macro. This is useful when you need +to check for preconditions of the system under test during runtime and skip +tests in a meaningful way. + +`GTEST_SKIP()` can be used in individual test cases or in the `SetUp()` methods +of classes derived from either `::testing::Environment` or `::testing::Test`. +For example: + +```c++ +TEST(SkipTest, DoesSkip) { + GTEST_SKIP() << "Skipping single test"; + EXPECT_EQ(0, 1); // Won't fail; it won't be executed +} + +class SkipFixture : public ::testing::Test { + protected: + void SetUp() override { + GTEST_SKIP() << "Skipping all tests for this fixture"; + } +}; + +// Tests for SkipFixture won't be executed. 
+TEST_F(SkipFixture, SkipsOneTest) { + EXPECT_EQ(5, 7); // Won't fail +} +``` + +As with assertion macros, you can stream a custom message into `GTEST_SKIP()`. + +## Teaching GoogleTest How to Print Your Values + +When a test assertion such as `EXPECT_EQ` fails, GoogleTest prints the argument +values to help you debug. It does this using a user-extensible value printer. + +This printer knows how to print built-in C++ types, native arrays, STL +containers, and any type that supports the `<<` operator. For other types, it +prints the raw bytes in the value and hopes that you the user can figure it out. + +As mentioned earlier, the printer is *extensible*. That means you can teach it +to do a better job at printing your particular type than to dump the bytes. To +do that, define an `AbslStringify()` overload as a `friend` function template +for your type: + +```cpp +namespace foo { + +class Point { // We want GoogleTest to be able to print instances of this. + ... + // Provide a friend overload. + template + friend void AbslStringify(Sink& sink, const Point& point) { + absl::Format(&sink, "(%d, %d)", point.x, point.y); + } + + int x; + int y; +}; + +// If you can't declare the function in the class it's important that the +// AbslStringify overload is defined in the SAME namespace that defines Point. +// C++'s look-up rules rely on that. +enum class EnumWithStringify { kMany = 0, kChoices = 1 }; + +template +void AbslStringify(Sink& sink, EnumWithStringify e) { + absl::Format(&sink, "%s", e == EnumWithStringify::kMany ? "Many" : "Choices"); +} + +} // namespace foo +``` + +{: .callout .note} +Note: `AbslStringify()` utilizes a generic "sink" buffer to construct its +string. For more information about supported operations on `AbslStringify()`'s +sink, see go/abslstringify. + +`AbslStringify()` can also use `absl::StrFormat`'s catch-all `%v` type specifier +within its own format strings to perform type deduction. 
`Point` above could be +formatted as `"(%v, %v)"` for example, and deduce the `int` values as `%d`. + +Sometimes, `AbslStringify()` might not be an option: your team may wish to print +types with extra debugging information for testing purposes only. If so, you can +instead define a `PrintTo()` function like this: + +```c++ +#include + +namespace foo { + +class Point { + ... + friend void PrintTo(const Point& point, std::ostream* os) { + *os << "(" << point.x << "," << point.y << ")"; + } + + int x; + int y; +}; + +// If you can't declare the function in the class it's important that PrintTo() +// is defined in the SAME namespace that defines Point. C++'s look-up rules +// rely on that. +void PrintTo(const Point& point, std::ostream* os) { + *os << "(" << point.x << "," << point.y << ")"; +} + +} // namespace foo +``` + +If you have defined both `AbslStringify()` and `PrintTo()`, the latter will be +used by GoogleTest. This allows you to customize how the value appears in +GoogleTest's output without affecting code that relies on the behavior of +`AbslStringify()`. + +If you have an existing `<<` operator and would like to define an +`AbslStringify()`, the latter will be used for GoogleTest printing. + +If you want to print a value `x` using GoogleTest's value printer yourself, just +call `::testing::PrintToString(x)`, which returns an `std::string`: + +```c++ +vector > point_ints = GetPointIntVector(); + +EXPECT_TRUE(IsCorrectPointIntVector(point_ints)) + << "point_ints = " << testing::PrintToString(point_ints); +``` + +For more details regarding `AbslStringify()` and its integration with other +libraries, see go/abslstringify. + +## Death Tests + +In many applications, there are assertions that can cause application failure if +a condition is not met. These consistency checks, which ensure that the program +is in a known good state, are there to fail at the earliest possible time after +some program state is corrupted. 
If the assertion checks the wrong condition, +then the program may proceed in an erroneous state, which could lead to memory +corruption, security holes, or worse. Hence it is vitally important to test that +such assertion statements work as expected. + +Since these precondition checks cause the processes to die, we call such tests +_death tests_. More generally, any test that checks that a program terminates +(except by throwing an exception) in an expected fashion is also a death test. + +Note that if a piece of code throws an exception, we don't consider it "death" +for the purpose of death tests, as the caller of the code could catch the +exception and avoid the crash. If you want to verify exceptions thrown by your +code, see [Exception Assertions](#ExceptionAssertions). + +If you want to test `EXPECT_*()/ASSERT_*()` failures in your test code, see +["Catching" Failures](#catching-failures). + +### How to Write a Death Test + +GoogleTest provides assertion macros to support death tests. See +[Death Assertions](reference/assertions.md#death) in the Assertions Reference +for details. + +To write a death test, simply use one of the macros inside your test function. +For example, + +```c++ +TEST(MyDeathTest, Foo) { + // This death test uses a compound statement. + ASSERT_DEATH({ + int n = 5; + Foo(&n); + }, "Error on line .* of Foo()"); +} + +TEST(MyDeathTest, NormalExit) { + EXPECT_EXIT(NormalExit(), testing::ExitedWithCode(0), "Success"); +} + +TEST(MyDeathTest, KillProcess) { + EXPECT_EXIT(KillProcess(), testing::KilledBySignal(SIGKILL), + "Sending myself unblockable signal"); +} +``` + +verifies that: + +* calling `Foo(5)` causes the process to die with the given error message, +* calling `NormalExit()` causes the process to print `"Success"` to stderr and + exit with exit code 0, and +* calling `KillProcess()` kills the process with signal `SIGKILL`. + +The test function body may contain other assertions and statements as well, if +necessary. 
+ +Note that a death test only cares about three things: + +1. does `statement` abort or exit the process? +2. (in the case of `ASSERT_EXIT` and `EXPECT_EXIT`) does the exit status + satisfy `predicate`? Or (in the case of `ASSERT_DEATH` and `EXPECT_DEATH`) + is the exit status non-zero? And +3. does the stderr output match `matcher`? + +In particular, if `statement` generates an `ASSERT_*` or `EXPECT_*` failure, it +will **not** cause the death test to fail, as GoogleTest assertions don't abort +the process. + +### Death Test Naming + +{: .callout .important} +IMPORTANT: We strongly recommend you to follow the convention of naming your +**test suite** (not test) `*DeathTest` when it contains a death test, as +demonstrated in the above example. The +[Death Tests And Threads](#death-tests-and-threads) section below explains why. + +If a test fixture class is shared by normal tests and death tests, you can use +`using` or `typedef` to introduce an alias for the fixture class and avoid +duplicating its code: + +```c++ +class FooTest : public testing::Test { ... }; + +using FooDeathTest = FooTest; + +TEST_F(FooTest, DoesThis) { + // normal test +} + +TEST_F(FooDeathTest, DoesThat) { + // death test +} +``` + +### Regular Expression Syntax + +When built with Bazel and using Abseil, GoogleTest uses the +[RE2](https://github.com/google/re2/wiki/Syntax) syntax. Otherwise, for POSIX +systems (Linux, Cygwin, Mac), GoogleTest uses the +[POSIX extended regular expression](http://www.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap09.html#tag_09_04) +syntax. To learn about POSIX syntax, you may want to read this +[Wikipedia entry](http://en.wikipedia.org/wiki/Regular_expression#POSIX_extended). + +On Windows, GoogleTest uses its own simple regular expression implementation. It +lacks many features. For example, we don't support union (`"x|y"`), grouping +(`"(xy)"`), brackets (`"[xy]"`), and repetition count (`"x{5,7}"`), among +others. 
Below is what we do support (`A` denotes a literal character, period +(`.`), or a single `\\ ` escape sequence; `x` and `y` denote regular +expressions.): + +Expression | Meaning +---------- | -------------------------------------------------------------- +`c` | matches any literal character `c` +`\\d` | matches any decimal digit +`\\D` | matches any character that's not a decimal digit +`\\f` | matches `\f` +`\\n` | matches `\n` +`\\r` | matches `\r` +`\\s` | matches any ASCII whitespace, including `\n` +`\\S` | matches any character that's not a whitespace +`\\t` | matches `\t` +`\\v` | matches `\v` +`\\w` | matches any letter, `_`, or decimal digit +`\\W` | matches any character that `\\w` doesn't match +`\\c` | matches any literal character `c`, which must be a punctuation +`.` | matches any single character except `\n` +`A?` | matches 0 or 1 occurrences of `A` +`A*` | matches 0 or many occurrences of `A` +`A+` | matches 1 or many occurrences of `A` +`^` | matches the beginning of a string (not that of each line) +`$` | matches the end of a string (not that of each line) +`xy` | matches `x` followed by `y` + +To help you determine which capability is available on your system, GoogleTest +defines macros to govern which regular expression it is using. The macros are: +`GTEST_USES_SIMPLE_RE=1` or `GTEST_USES_POSIX_RE=1`. If you want your death +tests to work in all cases, you can either `#if` on these macros or use the more +limited syntax only. + +### How It Works + +See [Death Assertions](reference/assertions.md#death) in the Assertions +Reference. + +### Death Tests And Threads + +The reason for the two death test styles has to do with thread safety. Due to +well-known problems with forking in the presence of threads, death tests should +be run in a single-threaded context. Sometimes, however, it isn't feasible to +arrange that kind of environment. For example, statically-initialized modules +may start threads before main is ever reached. 
Once threads have been created, +it may be difficult or impossible to clean them up. + +GoogleTest has three features intended to raise awareness of threading issues. + +1. A warning is emitted if multiple threads are running when a death test is + encountered. +2. Test suites with a name ending in "DeathTest" are run before all other + tests. +3. It uses `clone()` instead of `fork()` to spawn the child process on Linux + (`clone()` is not available on Cygwin and Mac), as `fork()` is more likely + to cause the child to hang when the parent process has multiple threads. + +It's perfectly fine to create threads inside a death test statement; they are +executed in a separate process and cannot affect the parent. + +### Death Test Styles + +The "threadsafe" death test style was introduced in order to help mitigate the +risks of testing in a possibly multithreaded environment. It trades increased +test execution time (potentially dramatically so) for improved thread safety. + +The automated testing framework does not set the style flag. You can choose a +particular style of death tests by setting the flag programmatically: + +```c++ +GTEST_FLAG_SET(death_test_style, "threadsafe"); +``` + +You can do this in `main()` to set the style for all death tests in the binary, +or in individual tests. Recall that flags are saved before running each test and +restored afterwards, so you need not do that yourself. For example: + +```c++ +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + GTEST_FLAG_SET(death_test_style, "fast"); + return RUN_ALL_TESTS(); +} + +TEST(MyDeathTest, TestOne) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + // This test is run in the "threadsafe" style: + ASSERT_DEATH(ThisShouldDie(), ""); +} + +TEST(MyDeathTest, TestTwo) { + // This test is run in the "fast" style: + ASSERT_DEATH(ThisShouldDie(), ""); +} +``` + +### Caveats + +The `statement` argument of `ASSERT_EXIT()` can be any valid C++ statement. 
If +it leaves the current function via a `return` statement or by throwing an +exception, the death test is considered to have failed. Some GoogleTest macros +may return from the current function (e.g. `ASSERT_TRUE()`), so be sure to avoid +them in `statement`. + +Since `statement` runs in the child process, any in-memory side effect (e.g. +modifying a variable, releasing memory, etc) it causes will *not* be observable +in the parent process. In particular, if you release memory in a death test, +your program will fail the heap check as the parent process will never see the +memory reclaimed. To solve this problem, you can + +1. try not to free memory in a death test; +2. free the memory again in the parent process; or +3. do not use the heap checker in your program. + +Due to an implementation detail, you cannot place multiple death test assertions +on the same line; otherwise, compilation will fail with an unobvious error +message. + +Despite the improved thread safety afforded by the "threadsafe" style of death +test, thread problems such as deadlock are still possible in the presence of +handlers registered with `pthread_atfork(3)`. + +## Using Assertions in Sub-routines + +{: .callout .note} +Note: If you want to put a series of test assertions in a subroutine to check +for a complex condition, consider using +[a custom GMock matcher](gmock_cook_book.md#NewMatchers) instead. This lets you +provide a more readable error message in case of failure and avoid all of the +issues described below. + +### Adding Traces to Assertions + +If a test sub-routine is called from several places, when an assertion inside it +fails, it can be hard to tell which invocation of the sub-routine the failure is +from. You can alleviate this problem using extra logging or custom failure +messages, but that usually clutters up your tests. 
A better solution is to use +the `SCOPED_TRACE` macro or the `ScopedTrace` utility: + +```c++ +SCOPED_TRACE(message); +``` + +```c++ +ScopedTrace trace("file_path", line_number, message); +``` + +where `message` can be anything streamable to `std::ostream`. `SCOPED_TRACE` +macro will cause the current file name, line number, and the given message to be +added in every failure message. `ScopedTrace` accepts explicit file name and +line number in arguments, which is useful for writing test helpers. The effect +will be undone when the control leaves the current lexical scope. + +For example, + +```c++ +10: void Sub1(int n) { +11: EXPECT_EQ(Bar(n), 1); +12: EXPECT_EQ(Bar(n + 1), 2); +13: } +14: +15: TEST(FooTest, Bar) { +16: { +17: SCOPED_TRACE("A"); // This trace point will be included in +18: // every failure in this scope. +19: Sub1(1); +20: } +21: // Now it won't. +22: Sub1(9); +23: } +``` + +could result in messages like these: + +```none +path/to/foo_test.cc:11: Failure +Value of: Bar(n) +Expected: 1 + Actual: 2 +Google Test trace: +path/to/foo_test.cc:17: A + +path/to/foo_test.cc:12: Failure +Value of: Bar(n + 1) +Expected: 2 + Actual: 3 +``` + +Without the trace, it would've been difficult to know which invocation of +`Sub1()` the two failures come from respectively. (You could add an extra +message to each assertion in `Sub1()` to indicate the value of `n`, but that's +tedious.) + +Some tips on using `SCOPED_TRACE`: + +1. With a suitable message, it's often enough to use `SCOPED_TRACE` at the + beginning of a sub-routine, instead of at each call site. +2. When calling sub-routines inside a loop, make the loop iterator part of the + message in `SCOPED_TRACE` such that you can know which iteration the failure + is from. +3. Sometimes the line number of the trace point is enough for identifying the + particular invocation of a sub-routine. In this case, you don't have to + choose a unique message for `SCOPED_TRACE`. You can simply use `""`. +4. 
You can use `SCOPED_TRACE` in an inner scope when there is one in the outer + scope. In this case, all active trace points will be included in the failure + messages, in reverse order they are encountered. +5. The trace dump is clickable in Emacs - hit `return` on a line number and + you'll be taken to that line in the source file! + +### Propagating Fatal Failures + +A common pitfall when using `ASSERT_*` and `FAIL*` is not understanding that +when they fail they only abort the _current function_, not the entire test. For +example, the following test will segfault: + +```c++ +void Subroutine() { + // Generates a fatal failure and aborts the current function. + ASSERT_EQ(1, 2); + + // The following won't be executed. + ... +} + +TEST(FooTest, Bar) { + Subroutine(); // The intended behavior is for the fatal failure + // in Subroutine() to abort the entire test. + + // The actual behavior: the function goes on after Subroutine() returns. + int* p = nullptr; + *p = 3; // Segfault! +} +``` + +To alleviate this, GoogleTest provides three different solutions. You could use +either exceptions, the `(ASSERT|EXPECT)_NO_FATAL_FAILURE` assertions or the +`HasFatalFailure()` function. They are described in the following two +subsections. + +#### Asserting on Subroutines with an exception + +The following code can turn ASSERT-failure into an exception: + +```c++ +class ThrowListener : public testing::EmptyTestEventListener { + void OnTestPartResult(const testing::TestPartResult& result) override { + if (result.type() == testing::TestPartResult::kFatalFailure) { + throw testing::AssertionException(result); + } + } +}; +int main(int argc, char** argv) { + ... + testing::UnitTest::GetInstance()->listeners().Append(new ThrowListener); + return RUN_ALL_TESTS(); +} +``` + +This listener should be added after other listeners if you have any, otherwise +they won't see failed `OnTestPartResult`. 
+ +#### Asserting on Subroutines + +As shown above, if your test calls a subroutine that has an `ASSERT_*` failure +in it, the test will continue after the subroutine returns. This may not be what +you want. + +Often people want fatal failures to propagate like exceptions. For that +GoogleTest offers the following macros: + +Fatal assertion | Nonfatal assertion | Verifies +------------------------------------- | ------------------------------------- | -------- +`ASSERT_NO_FATAL_FAILURE(statement);` | `EXPECT_NO_FATAL_FAILURE(statement);` | `statement` doesn't generate any new fatal failures in the current thread. + +Only failures in the thread that executes the assertion are checked to determine +the result of this type of assertions. If `statement` creates new threads, +failures in these threads are ignored. + +Examples: + +```c++ +ASSERT_NO_FATAL_FAILURE(Foo()); + +int i; +EXPECT_NO_FATAL_FAILURE({ + i = Bar(); +}); +``` + +Assertions from multiple threads are currently not supported on Windows. + +#### Checking for Failures in the Current Test + +`HasFatalFailure()` in the `::testing::Test` class returns `true` if an +assertion in the current test has suffered a fatal failure. This allows +functions to catch fatal failures in a sub-routine and return early. + +```c++ +class Test { + public: + ... + static bool HasFatalFailure(); +}; +``` + +The typical usage, which basically simulates the behavior of a thrown exception, +is: + +```c++ +TEST(FooTest, Bar) { + Subroutine(); + // Aborts if Subroutine() had a fatal failure. + if (HasFatalFailure()) return; + + // The following won't be executed. + ... 
+} +``` + +If `HasFatalFailure()` is used outside of `TEST()` , `TEST_F()` , or a test +fixture, you must add the `::testing::Test::` prefix, as in: + +```c++ +if (testing::Test::HasFatalFailure()) return; +``` + +Similarly, `HasNonfatalFailure()` returns `true` if the current test has at +least one non-fatal failure, and `HasFailure()` returns `true` if the current +test has at least one failure of either kind. + +## Logging Additional Information + +In your test code, you can call `RecordProperty("key", value)` to log additional +information, where `value` can be either a string or an `int`. The *last* value +recorded for a key will be emitted to the +[XML output](#generating-an-xml-report) if you specify one. For example, the +test + +```c++ +TEST_F(WidgetUsageTest, MinAndMaxWidgets) { + RecordProperty("MaximumWidgets", ComputeMaxUsage()); + RecordProperty("MinimumWidgets", ComputeMinUsage()); +} +``` + +will output XML like this: + +```xml + ... + <testcase name="MinAndMaxWidgets" status="run" time="0.006" classname="WidgetUsageTest" MaximumWidgets="12" MinimumWidgets="9" /> + ... +``` + +{: .callout .note} +> NOTE: +> +> * `RecordProperty()` is a static member of the `Test` class. Therefore it +> needs to be prefixed with `::testing::Test::` if used outside of the +> `TEST` body and the test fixture class. +> * *`key`* must be a valid XML attribute name, and cannot conflict with the +> ones already used by GoogleTest (`name`, `status`, `time`, `classname`, +> `type_param`, and `value_param`). +> * Calling `RecordProperty()` outside of the lifespan of a test is allowed. +> If it's called outside of a test but between a test suite's +> `SetUpTestSuite()` and `TearDownTestSuite()` methods, it will be +> attributed to the XML element for the test suite. If it's called outside +> of all test suites (e.g. in a test environment), it will be attributed to +> the top-level XML element. + +## Sharing Resources Between Tests in the Same Test Suite + +GoogleTest creates a new test fixture object for each test in order to make +tests independent and easier to debug.
However, sometimes tests use resources +that are expensive to set up, making the one-copy-per-test model prohibitively +expensive. + +If the tests don't change the resource, there's no harm in their sharing a +single resource copy. So, in addition to per-test set-up/tear-down, GoogleTest +also supports per-test-suite set-up/tear-down. To use it: + +1. In your test fixture class (say `FooTest` ), declare as `static` some member + variables to hold the shared resources. +2. Outside your test fixture class (typically just below it), define those + member variables, optionally giving them initial values. +3. In the same test fixture class, define a `static void SetUpTestSuite()` + function (remember not to spell it as **`SetupTestSuite`** with a small + `u`!) to set up the shared resources and a `static void TearDownTestSuite()` + function to tear them down. + +That's it! GoogleTest automatically calls `SetUpTestSuite()` before running the +*first test* in the `FooTest` test suite (i.e. before creating the first +`FooTest` object), and calls `TearDownTestSuite()` after running the *last test* +in it (i.e. after deleting the last `FooTest` object). In between, the tests can +use the shared resources. + +Remember that the test order is undefined, so your code can't depend on a test +preceding or following another. Also, the tests must either not modify the state +of any shared resource, or, if they do modify the state, they must restore the +state to its original value before passing control to the next test. + +Note that `SetUpTestSuite()` may be called multiple times for a test fixture +class that has derived classes, so you should not expect code in the function +body to be run only once. Also, derived classes still have access to shared +resources defined as static members, so careful consideration is needed when +managing shared resources to avoid memory leaks if shared resources are not +properly cleaned up in `TearDownTestSuite()`. 
+ +Here's an example of per-test-suite set-up and tear-down: + +```c++ +class FooTest : public testing::Test { + protected: + // Per-test-suite set-up. + // Called before the first test in this test suite. + // Can be omitted if not needed. + static void SetUpTestSuite() { + shared_resource_ = new ...; + + // If `shared_resource_` is **not deleted** in `TearDownTestSuite()`, + // reallocation should be prevented because `SetUpTestSuite()` may be called + // in subclasses of FooTest and lead to memory leak. + // + // if (shared_resource_ == nullptr) { + // shared_resource_ = new ...; + // } + } + + // Per-test-suite tear-down. + // Called after the last test in this test suite. + // Can be omitted if not needed. + static void TearDownTestSuite() { + delete shared_resource_; + shared_resource_ = nullptr; + } + + // You can define per-test set-up logic as usual. + void SetUp() override { ... } + + // You can define per-test tear-down logic as usual. + void TearDown() override { ... } + + // Some expensive resource shared by all tests. + static T* shared_resource_; +}; + +T* FooTest::shared_resource_ = nullptr; + +TEST_F(FooTest, Test1) { + ... you can refer to shared_resource_ here ... +} + +TEST_F(FooTest, Test2) { + ... you can refer to shared_resource_ here ... +} +``` + +{: .callout .note} +NOTE: Though the above code declares `SetUpTestSuite()` protected, it may +sometimes be necessary to declare it public, such as when using it with +`TEST_P`. + +## Global Set-Up and Tear-Down + +Just as you can do set-up and tear-down at the test level and the test suite +level, you can also do it at the test program level. Here's how. + +First, you subclass the `::testing::Environment` class to define a test +environment, which knows how to set-up and tear-down: + +```c++ +class Environment : public ::testing::Environment { + public: + ~Environment() override {} + + // Override this to define how to set up the environment. 
+ void SetUp() override {} + + // Override this to define how to tear down the environment. + void TearDown() override {} +}; +``` + +Then, you register an instance of your environment class with GoogleTest by +calling the `::testing::AddGlobalTestEnvironment()` function: + +```c++ +Environment* AddGlobalTestEnvironment(Environment* env); +``` + +Now, when `RUN_ALL_TESTS()` is called, it first calls the `SetUp()` method of +each environment object, then runs the tests if none of the environments +reported fatal failures and `GTEST_SKIP()` was not called. `RUN_ALL_TESTS()` +always calls `TearDown()` with each environment object, regardless of whether or +not the tests were run. + +It's OK to register multiple environment objects. In this suite, their `SetUp()` +will be called in the order they are registered, and their `TearDown()` will be +called in the reverse order. + +Note that GoogleTest takes ownership of the registered environment objects. +Therefore **do not delete them** by yourself. + +You should call `AddGlobalTestEnvironment()` before `RUN_ALL_TESTS()` is called, +probably in `main()`. If you use `gtest_main`, you need to call this before +`main()` starts for it to take effect. One way to do this is to define a global +variable like this: + +```c++ +testing::Environment* const foo_env = + testing::AddGlobalTestEnvironment(new FooEnvironment); +``` + +However, we strongly recommend you to write your own `main()` and call +`AddGlobalTestEnvironment()` there, as relying on initialization of global +variables makes the code harder to read and may cause problems when you register +multiple environments from different translation units and the environments have +dependencies among them (remember that the compiler doesn't guarantee the order +in which global variables from different translation units are initialized). 
+ +## Value-Parameterized Tests + +*Value-parameterized tests* allow you to test your code with different +parameters without writing multiple copies of the same test. This is useful in a +number of situations, for example: + +* You have a piece of code whose behavior is affected by one or more + command-line flags. You want to make sure your code performs correctly for + various values of those flags. +* You want to test different implementations of an OO interface. +* You want to test your code over various inputs (a.k.a. data-driven testing). + This feature is easy to abuse, so please exercise your good sense when doing + it! + +### How to Write Value-Parameterized Tests + +To write value-parameterized tests, first you should define a fixture class. It +must be derived from both `testing::Test` and `testing::WithParamInterface<T>` +(the latter is a pure interface), where `T` is the type of your parameter +values. For convenience, you can just derive the fixture class from +`testing::TestWithParam<T>`, which itself is derived from both `testing::Test` +and `testing::WithParamInterface<T>`. `T` can be any copyable type. If it's a +raw pointer, you are responsible for managing the lifespan of the pointed +values. + +{: .callout .note} +NOTE: If your test fixture defines `SetUpTestSuite()` or `TearDownTestSuite()` +they must be declared **public** rather than **protected** in order to use +`TEST_P`. + +```c++ +class FooTest : + public testing::TestWithParam<absl::string_view> { + // You can implement all the usual fixture class members here. + // To access the test parameter, call GetParam() from class + // TestWithParam<T>. +}; + +// Or, when you want to add parameters to a pre-existing fixture class: +class BaseTest : public testing::Test { + ... +}; +class BarTest : public BaseTest, + public testing::WithParamInterface<absl::string_view> { + ... +}; +``` + +Then, use the `TEST_P` macro to define as many test patterns using this fixture +as you want.
The `_P` suffix is for "parameterized" or "pattern", whichever you +prefer to think. + +```c++ +TEST_P(FooTest, DoesBlah) { + // Inside a test, access the test parameter with the GetParam() method + // of the TestWithParam class: + EXPECT_TRUE(foo.Blah(GetParam())); + ... +} + +TEST_P(FooTest, HasBlahBlah) { + ... +} +``` + +Finally, you can use the `INSTANTIATE_TEST_SUITE_P` macro to instantiate the +test suite with any set of parameters you want. GoogleTest defines a number of +functions for generating test parameters—see details at +[`INSTANTIATE_TEST_SUITE_P`](reference/testing.md#INSTANTIATE_TEST_SUITE_P) in +the Testing Reference. + +For example, the following statement will instantiate tests from the `FooTest` +test suite each with parameter values `"meeny"`, `"miny"`, and `"moe"` using the +[`Values`](reference/testing.md#param-generators) parameter generator: + +```c++ +INSTANTIATE_TEST_SUITE_P(MeenyMinyMoe, + FooTest, + testing::Values("meeny", "miny", "moe")); +``` + +{: .callout .note} +NOTE: The code above must be placed at global or namespace scope, not at +function scope. + +The first argument to `INSTANTIATE_TEST_SUITE_P` is a unique name for the +instantiation of the test suite. The next argument is the name of the test +pattern, and the last is the +[parameter generator](reference/testing.md#param-generators). + +The parameter generator expression is not evaluated until GoogleTest is +initialized (via `InitGoogleTest()`). Any prior initialization done in the +`main` function will be accessible from the parameter generator, for example, +the results of flag parsing. + +You can instantiate a test pattern more than once, so to distinguish different +instances of the pattern, the instantiation name is added as a prefix to the +actual test suite name. Remember to pick unique prefixes for different +instantiations. 
The tests from the instantiation above will have these names: + +* `MeenyMinyMoe/FooTest.DoesBlah/0` for `"meeny"` +* `MeenyMinyMoe/FooTest.DoesBlah/1` for `"miny"` +* `MeenyMinyMoe/FooTest.DoesBlah/2` for `"moe"` +* `MeenyMinyMoe/FooTest.HasBlahBlah/0` for `"meeny"` +* `MeenyMinyMoe/FooTest.HasBlahBlah/1` for `"miny"` +* `MeenyMinyMoe/FooTest.HasBlahBlah/2` for `"moe"` + +You can use these names in [`--gtest_filter`](#running-a-subset-of-the-tests). + +The following statement will instantiate all tests from `FooTest` again, each +with parameter values `"cat"` and `"dog"` using the +[`ValuesIn`](reference/testing.md#param-generators) parameter generator: + +```c++ +constexpr absl::string_view kPets[] = {"cat", "dog"}; +INSTANTIATE_TEST_SUITE_P(Pets, FooTest, testing::ValuesIn(kPets)); +``` + +The tests from the instantiation above will have these names: + +* `Pets/FooTest.DoesBlah/0` for `"cat"` +* `Pets/FooTest.DoesBlah/1` for `"dog"` +* `Pets/FooTest.HasBlahBlah/0` for `"cat"` +* `Pets/FooTest.HasBlahBlah/1` for `"dog"` + +Please note that `INSTANTIATE_TEST_SUITE_P` will instantiate *all* tests in the +given test suite, whether their definitions come before or *after* the +`INSTANTIATE_TEST_SUITE_P` statement. + +Additionally, by default, every `TEST_P` without a corresponding +`INSTANTIATE_TEST_SUITE_P` causes a failing test in test suite +`GoogleTestVerification`. If you have a test suite where that omission is not an +error, for example it is in a library that may be linked in for other reasons or +where the list of test cases is dynamic and may be empty, then this check can be +suppressed by tagging the test suite: + +```c++ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(FooTest); +``` + +You can see [sample7_unittest.cc] and [sample8_unittest.cc] for more examples. 
+ +[sample7_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample7_unittest.cc "Parameterized Test example" +[sample8_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc "Parameterized Test example with multiple parameters" + +### Creating Value-Parameterized Abstract Tests + +In the above, we define and instantiate `FooTest` in the *same* source file. +Sometimes you may want to define value-parameterized tests in a library and let +other people instantiate them later. This pattern is known as *abstract tests*. +As an example of its application, when you are designing an interface you can +write a standard suite of abstract tests (perhaps using a factory function as +the test parameter) that all implementations of the interface are expected to +pass. When someone implements the interface, they can instantiate your suite to +get all the interface-conformance tests for free. + +To define abstract tests, you should organize your code like this: + +1. Put the definition of the parameterized test fixture class (e.g. `FooTest`) + in a header file, say `foo_param_test.h`. Think of this as *declaring* your + abstract tests. +2. Put the `TEST_P` definitions in `foo_param_test.cc`, which includes + `foo_param_test.h`. Think of this as *implementing* your abstract tests. + +Once they are defined, you can instantiate them by including `foo_param_test.h`, +invoking `INSTANTIATE_TEST_SUITE_P()`, and depending on the library target that +contains `foo_param_test.cc`. You can instantiate the same abstract test suite +multiple times, possibly in different source files. + +### Specifying Names for Value-Parameterized Test Parameters + +The optional last argument to `INSTANTIATE_TEST_SUITE_P()` allows the user to +specify a function or functor that generates custom test name suffixes based on +the test parameters. 
The function should accept one argument of type +`testing::TestParamInfo<class ParamType>`, and return `std::string`. + +`testing::PrintToStringParamName` is a builtin test suffix generator that +returns the value of `testing::PrintToString(GetParam())`. It does not work for +`std::string` or C strings. + +{: .callout .note} +NOTE: test names must be non-empty, unique, and may only contain ASCII +alphanumeric characters. In particular, they +[should not contain underscores](faq.md#why-should-test-suite-names-and-test-names-not-contain-underscore) + +```c++ +class MyTestSuite : public testing::TestWithParam<int> {}; + +TEST_P(MyTestSuite, MyTest) +{ + std::cout << "Example Test Param: " << GetParam() << std::endl; +} + +INSTANTIATE_TEST_SUITE_P(MyGroup, MyTestSuite, testing::Range(0, 10), + testing::PrintToStringParamName()); +``` + +Providing a custom functor allows for more control over test parameter name +generation, especially for types where the automatic conversion does not +generate helpful parameter names (e.g. strings as demonstrated above). The +following example illustrates this for multiple parameters, an enumeration type +and a string, and also demonstrates how to combine generators. It uses a lambda +for conciseness: + +```c++ +enum class MyType { MY_FOO = 0, MY_BAR = 1 }; + +class MyTestSuite : public testing::TestWithParam<std::tuple<MyType, std::string>> { +}; + +INSTANTIATE_TEST_SUITE_P( + MyGroup, MyTestSuite, + testing::Combine( + testing::Values(MyType::MY_FOO, MyType::MY_BAR), + testing::Values("A", "B")), + [](const testing::TestParamInfo<MyTestSuite::ParamType>& info) { + std::string name = absl::StrCat( + std::get<0>(info.param) == MyType::MY_FOO ? "Foo" : "Bar", + std::get<1>(info.param)); + absl::c_replace_if(name, [](char c) { return !std::isalnum(c); }, '_'); + return name; + }); +``` + +## Typed Tests + +Suppose you have multiple implementations of the same interface and want to make +sure that all of them satisfy some common requirements.
Or, you may have defined +several types that are supposed to conform to the same "concept" and you want to +verify it. In both cases, you want the same test logic repeated for different +types. + +While you can write one `TEST` or `TEST_F` for each type you want to test (and +you may even factor the test logic into a function template that you invoke from +the `TEST`), it's tedious and doesn't scale: if you want `m` tests over `n` +types, you'll end up writing `m*n` `TEST`s. + +*Typed tests* allow you to repeat the same test logic over a list of types. You +only need to write the test logic once, although you must know the type list +when writing typed tests. Here's how you do it: + +First, define a fixture class template. It should be parameterized by a type. +Remember to derive it from `::testing::Test`: + +```c++ +template <typename T> +class FooTest : public testing::Test { + public: + ... + using List = std::list<T>; + static T shared_; + T value_; +}; +``` + +Next, associate a list of types with the test suite, which will be repeated for +each type in the list: + +```c++ +using MyTypes = ::testing::Types<char, int, unsigned int>; +TYPED_TEST_SUITE(FooTest, MyTypes); +``` + +The type alias (`using` or `typedef`) is necessary for the `TYPED_TEST_SUITE` +macro to parse correctly. Otherwise the compiler will think that each comma in +the type list introduces a new macro argument. + +Then, use `TYPED_TEST()` instead of `TEST_F()` to define a typed test for this +test suite. You can repeat this as many times as you want: + +```c++ +TYPED_TEST(FooTest, DoesBlah) { + // Inside a test, refer to the special name TypeParam to get the type + // parameter. Since we are inside a derived class template, C++ requires + // us to visit the members of FooTest via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the 'TestFixture::' + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the 'typename TestFixture::' + // prefix.
The 'typename' is required to satisfy the compiler. + typename TestFixture::List values; + + values.push_back(n); + ... +} + +TYPED_TEST(FooTest, HasPropertyA) { ... } +``` + +You can see [sample6_unittest.cc] for a complete example. + +[sample6_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample6_unittest.cc "Typed Test example" + +## Type-Parameterized Tests + +*Type-parameterized tests* are like typed tests, except that they don't require +you to know the list of types ahead of time. Instead, you can define the test +logic first and instantiate it with different type lists later. You can even +instantiate it more than once in the same program. + +If you are designing an interface or concept, you can define a suite of +type-parameterized tests to verify properties that any valid implementation of +the interface/concept should have. Then, the author of each implementation can +just instantiate the test suite with their type to verify that it conforms to +the requirements, without having to write similar tests repeatedly. Here's an +example: + +First, define a fixture class template, as we did with typed tests: + +```c++ +template <typename T> +class FooTest : public testing::Test { + void DoSomethingInteresting(); + ... +}; +``` + +Next, declare that you will define a type-parameterized test suite: + +```c++ +TYPED_TEST_SUITE_P(FooTest); +``` + +Then, use `TYPED_TEST_P()` to define a type-parameterized test. You can repeat +this as many times as you want: + +```c++ +TYPED_TEST_P(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + TypeParam n = 0; + + // You will need to use `this` explicitly to refer to fixture members. + this->DoSomethingInteresting() + ... +} + +TYPED_TEST_P(FooTest, HasPropertyA) { ... } +``` + +Now the tricky part: you need to register all test patterns using the +`REGISTER_TYPED_TEST_SUITE_P` macro before you can instantiate them.
The first +argument of the macro is the test suite name; the rest are the names of the +tests in this test suite: + +```c++ +REGISTER_TYPED_TEST_SUITE_P(FooTest, + DoesBlah, HasPropertyA); +``` + +Finally, you are free to instantiate the pattern with the types you want. If you +put the above code in a header file, you can `#include` it in multiple C++ +source files and instantiate it multiple times. + +```c++ +using MyTypes = ::testing::Types<char, int, unsigned int>; +INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes); +``` + +To distinguish different instances of the pattern, the first argument to the +`INSTANTIATE_TYPED_TEST_SUITE_P` macro is a prefix that will be added to the +actual test suite name. Remember to pick unique prefixes for different +instances. + +In the special case where the type list contains only one type, you can write +that type directly without `::testing::Types<...>`, like this: + +```c++ +INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, int); +``` + +You can see [sample6_unittest.cc] for a complete example. + +## Testing Private Code + +If you change your software's internal implementation, your tests should not +break as long as the change is not observable by users. Therefore, **per the +black-box testing principle, most of the time you should test your code through +its public interfaces.** + +**If you still find yourself needing to test internal implementation code, +consider if there's a better design.** The desire to test internal +implementation is often a sign that the class is doing too much. Consider +extracting an implementation class, and testing it. Then use that implementation +class in the original class. + +If you absolutely have to test non-public interface code though, you can. There +are two cases to consider: + +* Static functions ( *not* the same as static member functions!)
or unnamed + namespaces, and +* Private or protected class members + +To test them, we use the following special techniques: + +* Both static functions and definitions/declarations in an unnamed namespace + are only visible within the same translation unit. To test them, you can + `#include` the entire `.cc` file being tested in your `*_test.cc` file. + (#including `.cc` files is not a good way to reuse code - you should not do + this in production code!) + + However, a better approach is to move the private code into the + `foo::internal` namespace, where `foo` is the namespace your project + normally uses, and put the private declarations in a `*-internal.h` file. + Your production `.cc` files and your tests are allowed to include this + internal header, but your clients are not. This way, you can fully test your + internal implementation without leaking it to your clients. + +* Private class members are only accessible from within the class or by + friends. To access a class' private members, you can declare your test + fixture as a friend to the class and define accessors in your fixture. Tests + using the fixture can then access the private members of your production + class via the accessors in the fixture. Note that even though your fixture + is a friend to your production class, your tests are not automatically + friends to it, as they are technically defined in sub-classes of the + fixture. + + Another way to test private members is to refactor them into an + implementation class, which is then declared in a `*-internal.h` file. Your + clients aren't allowed to include this header but your tests can. Such is + called the + [Pimpl](https://www.gamedev.net/articles/programming/general-and-gameplay-programming/the-c-pimpl-r1794/) + (Private Implementation) idiom. 
+ + Or, you can declare an individual test as a friend of your class by adding + this line in the class body: + + ```c++ + FRIEND_TEST(TestSuiteName, TestName); + ``` + + For example, + + ```c++ + // foo.h + class Foo { + ... + private: + FRIEND_TEST(FooTest, BarReturnsZeroOnNull); + + int Bar(void* x); + }; + + // foo_test.cc + ... + TEST(FooTest, BarReturnsZeroOnNull) { + Foo foo; + EXPECT_EQ(foo.Bar(NULL), 0); // Uses Foo's private member Bar(). + } + ``` + + Pay special attention when your class is defined in a namespace. If you want + your test fixtures and tests to be friends of your class, then they must be + defined in the exact same namespace (no anonymous or inline namespaces). + + For example, if the code to be tested looks like: + + ```c++ + namespace my_namespace { + + class Foo { + friend class FooTest; + FRIEND_TEST(FooTest, Bar); + FRIEND_TEST(FooTest, Baz); + ... definition of the class Foo ... + }; + + } // namespace my_namespace + ``` + + Your test code should be something like: + + ```c++ + namespace my_namespace { + + class FooTest : public testing::Test { + protected: + ... + }; + + TEST_F(FooTest, Bar) { ... } + TEST_F(FooTest, Baz) { ... } + + } // namespace my_namespace + ``` + +## "Catching" Failures + +If you are building a testing utility on top of GoogleTest, you'll want to test +your utility. What framework would you use to test it? GoogleTest, of course. + +The challenge is to verify that your testing utility reports failures correctly. +In frameworks that report a failure by throwing an exception, you could catch +the exception and assert on it. But GoogleTest doesn't use exceptions, so how do +we test that a piece of code generates an expected failure? + +`"gtest/gtest-spi.h"` contains some constructs to do this. +After #including this header, you can use + +```c++ + EXPECT_FATAL_FAILURE(statement, substring); +``` + +to assert that `statement` generates a fatal (e.g. 
`ASSERT_*`) failure in the
+current thread whose message contains the given `substring`, or use
+
+```c++
+  EXPECT_NONFATAL_FAILURE(statement, substring);
+```
+
+if you are expecting a non-fatal (e.g. `EXPECT_*`) failure.
+
+Only failures in the current thread are checked to determine the result of this
+type of expectations. If `statement` creates new threads, failures in these
+threads are also ignored. If you want to catch failures in other threads as
+well, use one of the following macros instead:
+
+```c++
+  EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substring);
+  EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substring);
+```
+
+{: .callout .note}
+NOTE: Assertions from multiple threads are currently not supported on Windows.
+
+For technical reasons, there are some caveats:
+
+1.  You cannot stream a failure message to either macro.
+
+2.  `statement` in `EXPECT_FATAL_FAILURE{_ON_ALL_THREADS}()` cannot reference
+    local non-static variables or non-static members of `this` object.
+
+3.  `statement` in `EXPECT_FATAL_FAILURE{_ON_ALL_THREADS}()` cannot return a
+    value.
+
+## Registering tests programmatically
+
+The `TEST` macros handle the vast majority of all use cases, but there are a
+few where runtime registration logic is required. For those cases, the
+framework provides the `::testing::RegisterTest` that allows callers to
+register arbitrary tests dynamically.
+
+This is an advanced API only to be used when the `TEST` macros are insufficient.
+The macros should be preferred when possible, as they avoid most of the
+complexity of calling this function.
+
+It provides the following signature:
+
+```c++
+template <typename Factory>
+TestInfo* RegisterTest(const char* test_suite_name, const char* test_name,
+                       const char* type_param, const char* value_param,
+                       const char* file, int line, Factory factory);
+```
+
+The `factory` argument is a factory callable (move-constructible) object or
+function pointer that creates a new instance of the Test object. It hands
+ownership to the caller. The signature of the callable is `Fixture*()`, where
+`Fixture` is the test fixture class for the test. All tests registered with the
+same `test_suite_name` must return the same fixture type. This is checked at
+runtime.
+
+The framework will infer the fixture class from the factory and will call the
+`SetUpTestSuite` and `TearDownTestSuite` for it.
+
+Must be called before `RUN_ALL_TESTS()` is invoked, otherwise behavior is
+undefined.
+
+Use case example:
+
+```c++
+class MyFixture : public testing::Test {
+ public:
+  // All of these optional, just like in regular macro usage.
+  static void SetUpTestSuite() { ... }
+  static void TearDownTestSuite() { ... }
+  void SetUp() override { ... }
+  void TearDown() override { ... }
+};
+
+class MyTest : public MyFixture {
+ public:
+  explicit MyTest(int data) : data_(data) {}
+  void TestBody() override { ... }
+
+ private:
+  int data_;
+};
+
+void RegisterMyTests(const std::vector<int>& values) {
+  for (int v : values) {
+    testing::RegisterTest(
+        "MyFixture", ("Test" + std::to_string(v)).c_str(), nullptr,
+        std::to_string(v).c_str(),
+        __FILE__, __LINE__,
+        // Important to use the fixture type as the return type here.
+        [=]() -> MyFixture* { return new MyTest(v); });
+  }
+}
+...
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  std::vector<int> values_to_test = LoadValuesFromConfig();
+  RegisterMyTests(values_to_test);
+  ...
+  return RUN_ALL_TESTS();
+}
+```
+
+## Getting the Current Test's Name
+
+Sometimes a function may need to know the name of the currently running test.
+For example, you may be using the `SetUp()` method of your test fixture to set
+the golden file name based on which test is running. The
+[`TestInfo`](reference/testing.md#TestInfo) class has this information.
+
+To obtain a `TestInfo` object for the currently running test, call
+`current_test_info()` on the [`UnitTest`](reference/testing.md#UnitTest)
+singleton object:
+
+```c++
+  // Gets information about the currently running test.
+  // Do NOT delete the returned object - it's managed by the UnitTest class.
+  const testing::TestInfo* const test_info =
+      testing::UnitTest::GetInstance()->current_test_info();
+
+  printf("We are in test %s of test suite %s.\n",
+         test_info->name(),
+         test_info->test_suite_name());
+```
+
+`current_test_info()` returns a null pointer if no test is running. In
+particular, you cannot find the test suite name in `SetUpTestSuite()`,
+`TearDownTestSuite()` (where you know the test suite name implicitly), or
+functions called from them.
+
+## Extending GoogleTest by Handling Test Events
+
+GoogleTest provides an **event listener API** to let you receive notifications
+about the progress of a test program and test failures. The events you can
+listen to include the start and end of the test program, a test suite, or a test
+method, among others. You may use this API to augment or replace the standard
+console output, replace the XML output, or provide a completely different form
+of output, such as a GUI or a database. You can also use test events as
+checkpoints to implement a resource leak checker, for example.
+
+### Defining Event Listeners
+
+To define an event listener, you subclass either
+[`testing::TestEventListener`](reference/testing.md#TestEventListener) or
+[`testing::EmptyTestEventListener`](reference/testing.md#EmptyTestEventListener).
+The former is an (abstract) interface, where *each pure virtual method can be
+overridden to handle a test event* (for example, when a test starts, the
+`OnTestStart()` method will be called). The latter provides an empty
+implementation of all methods in the interface, such that a subclass only needs
+to override the methods it cares about.
+ +When an event is fired, its context is passed to the handler function as an +argument. The following argument types are used: + +* UnitTest reflects the state of the entire test program, +* TestSuite has information about a test suite, which can contain one or more + tests, +* TestInfo contains the state of a test, and +* TestPartResult represents the result of a test assertion. + +An event handler function can examine the argument it receives to find out +interesting information about the event and the test program's state. + +Here's an example: + +```c++ + class MinimalistPrinter : public testing::EmptyTestEventListener { + // Called before a test starts. + void OnTestStart(const testing::TestInfo& test_info) override { + printf("*** Test %s.%s starting.\n", + test_info.test_suite_name(), test_info.name()); + } + + // Called after a failed assertion or a SUCCESS(). + void OnTestPartResult(const testing::TestPartResult& test_part_result) override { + printf("%s in %s:%d\n%s\n", + test_part_result.failed() ? "*** Failure" : "Success", + test_part_result.file_name(), + test_part_result.line_number(), + test_part_result.summary()); + } + + // Called after a test ends. + void OnTestEnd(const testing::TestInfo& test_info) override { + printf("*** Test %s.%s ending.\n", + test_info.test_suite_name(), test_info.name()); + } + }; +``` + +### Using Event Listeners + +To use the event listener you have defined, add an instance of it to the +GoogleTest event listener list (represented by class +[`TestEventListeners`](reference/testing.md#TestEventListeners) - note the "s" +at the end of the name) in your `main()` function, before calling +`RUN_ALL_TESTS()`: + +```c++ +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Gets hold of the event listener list. + testing::TestEventListeners& listeners = + testing::UnitTest::GetInstance()->listeners(); + // Adds a listener to the end. GoogleTest takes the ownership. 
+ listeners.Append(new MinimalistPrinter); + return RUN_ALL_TESTS(); +} +``` + +There's only one problem: the default test result printer is still in effect, so +its output will mingle with the output from your minimalist printer. To suppress +the default printer, just release it from the event listener list and delete it. +You can do so by adding one line: + +```c++ + ... + delete listeners.Release(listeners.default_result_printer()); + listeners.Append(new MinimalistPrinter); + return RUN_ALL_TESTS(); +``` + +Now, sit back and enjoy a completely different output from your tests. For more +details, see [sample9_unittest.cc]. + +[sample9_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample9_unittest.cc "Event listener example" + +You may append more than one listener to the list. When an `On*Start()` or +`OnTestPartResult()` event is fired, the listeners will receive it in the order +they appear in the list (since new listeners are added to the end of the list, +the default text printer and the default XML generator will receive the event +first). An `On*End()` event will be received by the listeners in the *reverse* +order. This allows output by listeners added later to be framed by output from +listeners added earlier. + +### Generating Failures in Listeners + +You may use failure-raising macros (`EXPECT_*()`, `ASSERT_*()`, `FAIL()`, etc) +when processing an event. There are some restrictions: + +1. You cannot generate any failure in `OnTestPartResult()` (otherwise it will + cause `OnTestPartResult()` to be called recursively). +2. A listener that handles `OnTestPartResult()` is not allowed to generate any + failure. + +When you add listeners to the listener list, you should put listeners that +handle `OnTestPartResult()` *before* listeners that can generate failures. This +ensures that failures generated by the latter are attributed to the right test +by the former. 
+ +See [sample10_unittest.cc] for an example of a failure-raising listener. + +[sample10_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample10_unittest.cc "Failure-raising listener example" + +## Running Test Programs: Advanced Options + +GoogleTest test programs are ordinary executables. Once built, you can run them +directly and affect their behavior via the following environment variables +and/or command line flags. For the flags to work, your programs must call +`::testing::InitGoogleTest()` before calling `RUN_ALL_TESTS()`. + +To see a list of supported flags and their usage, please run your test program +with the `--help` flag. You can also use `-h`, `-?`, or `/?` for short. + +If an option is specified both by an environment variable and by a flag, the +latter takes precedence. + +### Selecting Tests + +#### Listing Test Names + +Sometimes it is necessary to list the available tests in a program before +running them so that a filter may be applied if needed. Including the flag +`--gtest_list_tests` overrides all other flags and lists tests in the following +format: + +```none +TestSuite1. + TestName1 + TestName2 +TestSuite2. + TestName +``` + +None of the tests listed are actually run if the flag is provided. There is no +corresponding environment variable for this flag. + +#### Running a Subset of the Tests + +By default, a GoogleTest program runs all tests the user has defined. Sometimes, +you want to run only a subset of the tests (e.g. for debugging or quickly +verifying a change). If you set the `GTEST_FILTER` environment variable or the +`--gtest_filter` flag to a filter string, GoogleTest will only run the tests +whose full names (in the form of `TestSuiteName.TestName`) match the filter. + +The format of a filter is a '`:`'-separated list of wildcard patterns (called +the *positive patterns*) optionally followed by a '`-`' and another +'`:`'-separated pattern list (called the *negative patterns*). 
A test matches
+the filter if and only if it matches any of the positive patterns but does not
+match any of the negative patterns.
+
+A pattern may contain `'*'` (matches any string) or `'?'` (matches any single
+character). For convenience, the filter `'*-NegativePatterns'` can be also
+written as `'-NegativePatterns'`.
+
+For example:
+
+*   `./foo_test` Has no flag, and thus runs all its tests.
+*   `./foo_test --gtest_filter=*` Also runs everything, due to the single
+    match-everything `*` value.
+*   `./foo_test --gtest_filter=FooTest.*` Runs everything in test suite
+    `FooTest` .
+*   `./foo_test --gtest_filter=*Null*:*Constructor*` Runs any test whose full
+    name contains either `"Null"` or `"Constructor"` .
+*   `./foo_test --gtest_filter=-*DeathTest.*` Runs all non-death tests.
+*   `./foo_test --gtest_filter=FooTest.*-FooTest.Bar` Runs everything in test
+    suite `FooTest` except `FooTest.Bar`.
+*   `./foo_test --gtest_filter=FooTest.*:BarTest.*-FooTest.Bar:BarTest.Foo` Runs
+    everything in test suite `FooTest` except `FooTest.Bar` and everything in
+    test suite `BarTest` except `BarTest.Foo`.
+
+#### Stop test execution upon first failure
+
+By default, a GoogleTest program runs all tests the user has defined. In some
+cases (e.g. iterative test development & execution) it may be desirable to stop
+test execution upon first failure (trading improved latency for completeness).
+If `GTEST_FAIL_FAST` environment variable or `--gtest_fail_fast` flag is set,
+the test runner will stop execution as soon as the first test failure is found.
+
+#### Temporarily Disabling Tests
+
+If you have a broken test that you cannot fix right away, you can add the
+`DISABLED_` prefix to its name. This will exclude it from execution. This is
+better than commenting out the code or using `#if 0`, as disabled tests are
+still compiled (and thus won't rot).
+ +If you need to disable all tests in a test suite, you can either add `DISABLED_` +to the front of the name of each test, or alternatively add it to the front of +the test suite name. + +For example, the following tests won't be run by GoogleTest, even though they +will still be compiled: + +```c++ +// Tests that Foo does Abc. +TEST(FooTest, DISABLED_DoesAbc) { ... } + +class DISABLED_BarTest : public testing::Test { ... }; + +// Tests that Bar does Xyz. +TEST_F(DISABLED_BarTest, DoesXyz) { ... } +``` + +{: .callout .note} +NOTE: This feature should only be used for temporary pain-relief. You still have +to fix the disabled tests at a later date. As a reminder, GoogleTest will print +a banner warning you if a test program contains any disabled tests. + +{: .callout .tip} +TIP: You can easily count the number of disabled tests you have using +`grep`. This number can be used as a metric for +improving your test quality. + +#### Temporarily Enabling Disabled Tests + +To include disabled tests in test execution, just invoke the test program with +the `--gtest_also_run_disabled_tests` flag or set the +`GTEST_ALSO_RUN_DISABLED_TESTS` environment variable to a value other than `0`. +You can combine this with the `--gtest_filter` flag to further select which +disabled tests to run. + +### Repeating the Tests + +Once in a while you'll run into a test whose result is hit-or-miss. Perhaps it +will fail only 1% of the time, making it rather hard to reproduce the bug under +a debugger. This can be a major source of frustration. + +The `--gtest_repeat` flag allows you to repeat all (or selected) test methods in +a program many times. Hopefully, a flaky test will eventually fail and give you +a chance to debug. Here's how to use it: + +```none +$ foo_test --gtest_repeat=1000 +Repeat foo_test 1000 times and don't stop at failures. + +$ foo_test --gtest_repeat=-1 +A negative count means repeating forever. 
+ +$ foo_test --gtest_repeat=1000 --gtest_break_on_failure +Repeat foo_test 1000 times, stopping at the first failure. This +is especially useful when running under a debugger: when the test +fails, it will drop into the debugger and you can then inspect +variables and stacks. + +$ foo_test --gtest_repeat=1000 --gtest_filter=FooBar.* +Repeat the tests whose name matches the filter 1000 times. +``` + +If your test program contains +[global set-up/tear-down](#global-set-up-and-tear-down) code, it will be +repeated in each iteration as well, as the flakiness may be in it. To avoid +repeating global set-up/tear-down, specify +`--gtest_recreate_environments_when_repeating=false`{.nowrap}. + +You can also specify the repeat count by setting the `GTEST_REPEAT` environment +variable. + +### Shuffling the Tests + +You can specify the `--gtest_shuffle` flag (or set the `GTEST_SHUFFLE` +environment variable to `1`) to run the tests in a program in a random order. +This helps to reveal bad dependencies between tests. + +By default, GoogleTest uses a random seed calculated from the current time. +Therefore you'll get a different order every time. The console output includes +the random seed value, such that you can reproduce an order-related test failure +later. To specify the random seed explicitly, use the `--gtest_random_seed=SEED` +flag (or set the `GTEST_RANDOM_SEED` environment variable), where `SEED` is an +integer in the range [0, 99999]. The seed value 0 is special: it tells +GoogleTest to do the default behavior of calculating the seed from the current +time. + +If you combine this with `--gtest_repeat=N`, GoogleTest will pick a different +random seed and re-shuffle the tests in each iteration. + +### Distributing Test Functions to Multiple Machines + +If you have more than one machine you can use to run a test program, you might +want to run the test functions in parallel and get the result faster. 
We call +this technique *sharding*, where each machine is called a *shard*. + +GoogleTest is compatible with test sharding. To take advantage of this feature, +your test runner (not part of GoogleTest) needs to do the following: + +1. Allocate a number of machines (shards) to run the tests. +1. On each shard, set the `GTEST_TOTAL_SHARDS` environment variable to the total + number of shards. It must be the same for all shards. +1. On each shard, set the `GTEST_SHARD_INDEX` environment variable to the index + of the shard. Different shards must be assigned different indices, which + must be in the range `[0, GTEST_TOTAL_SHARDS - 1]`. +1. Run the same test program on all shards. When GoogleTest sees the above two + environment variables, it will select a subset of the test functions to run. + Across all shards, each test function in the program will be run exactly + once. +1. Wait for all shards to finish, then collect and report the results. + +Your project may have tests that were written without GoogleTest and thus don't +understand this protocol. In order for your test runner to figure out which test +supports sharding, it can set the environment variable `GTEST_SHARD_STATUS_FILE` +to a non-existent file path. If a test program supports sharding, it will create +this file to acknowledge that fact; otherwise it will not create it. The actual +contents of the file are not important at this time, although we may put some +useful information in it in the future. + +Here's an example to make it clear. Suppose you have a test program `foo_test` +that contains the following 5 test functions: + +``` +TEST(A, V) +TEST(A, W) +TEST(B, X) +TEST(B, Y) +TEST(B, Z) +``` + +Suppose you have 3 machines at your disposal. To run the test functions in +parallel, you would set `GTEST_TOTAL_SHARDS` to 3 on all machines, and set +`GTEST_SHARD_INDEX` to 0, 1, and 2 on the machines respectively. Then you would +run the same `foo_test` on each machine. 
+ +GoogleTest reserves the right to change how the work is distributed across the +shards, but here's one possible scenario: + +* Machine #0 runs `A.V` and `B.X`. +* Machine #1 runs `A.W` and `B.Y`. +* Machine #2 runs `B.Z`. + +### Controlling Test Output + +#### Colored Terminal Output + +GoogleTest can use colors in its terminal output to make it easier to spot the +important information: + +
...
+[----------] 1 test from FooTest
+[ RUN      ] FooTest.DoesAbc
+[       OK ] FooTest.DoesAbc
+[----------] 2 tests from BarTest
+[ RUN      ] BarTest.HasXyzProperty
+[       OK ] BarTest.HasXyzProperty
+[ RUN      ] BarTest.ReturnsTrueOnSuccess
+... some error messages ...
+[   FAILED ] BarTest.ReturnsTrueOnSuccess
+...
+[==========] 30 tests from 14 test suites ran.
+[   PASSED ] 28 tests.
+[   FAILED ] 2 tests, listed below:
+[   FAILED ] BarTest.ReturnsTrueOnSuccess
+[   FAILED ] AnotherTest.DoesXyz
+
+ 2 FAILED TESTS
+
+ +You can set the `GTEST_COLOR` environment variable or the `--gtest_color` +command line flag to `yes`, `no`, or `auto` (the default) to enable colors, +disable colors, or let GoogleTest decide. When the value is `auto`, GoogleTest +will use colors if and only if the output goes to a terminal and (on non-Windows +platforms) the `TERM` environment variable is set to `xterm` or `xterm-color`. + +#### Suppressing test passes + +By default, GoogleTest prints 1 line of output for each test, indicating if it +passed or failed. To show only test failures, run the test program with +`--gtest_brief=1`, or set the GTEST_BRIEF environment variable to `1`. + +#### Suppressing the Elapsed Time + +By default, GoogleTest prints the time it takes to run each test. To disable +that, run the test program with the `--gtest_print_time=0` command line flag, or +set the GTEST_PRINT_TIME environment variable to `0`. + +#### Suppressing UTF-8 Text Output + +In case of assertion failures, GoogleTest prints expected and actual values of +type `string` both as hex-encoded strings as well as in readable UTF-8 text if +they contain valid non-ASCII UTF-8 characters. If you want to suppress the UTF-8 +text because, for example, you don't have an UTF-8 compatible output medium, run +the test program with `--gtest_print_utf8=0` or set the `GTEST_PRINT_UTF8` +environment variable to `0`. + +#### Generating an XML Report + +GoogleTest can emit a detailed XML report to a file in addition to its normal +textual output. The report contains the duration of each test, and thus can help +you identify slow tests. + +To generate the XML report, set the `GTEST_OUTPUT` environment variable or the +`--gtest_output` flag to the string `"xml:path_to_output_file"`, which will +create the file at the given location. You can also just use the string `"xml"`, +in which case the output can be found in the `test_detail.xml` file in the +current directory. 
+
+If you specify a directory (for example, `"xml:output/directory/"` on Linux or
+`"xml:output\directory\"` on Windows), GoogleTest will create the XML file in
+that directory, named after the test executable (e.g. `foo_test.xml` for test
+program `foo_test` or `foo_test.exe`). If the file already exists (perhaps left
+over from a previous run), GoogleTest will pick a different name (e.g.
+`foo_test_1.xml`) to avoid overwriting it.
+
+The report is based on the `junitreport` Ant task. Since that format was
+originally intended for Java, a little interpretation is required to make it
+apply to GoogleTest tests, as shown here:
+
+```xml
+<testsuites name="AllTests" ...>
+  <testsuite name="test_case_name" ...>
+    <testcase name="test_name" ...>
+      <failure message="..."/>
+      <failure message="..."/>
+      <failure message="..."/>
+    </testcase>
+  </testsuite>
+</testsuites>
+```
+
+*   The root `<testsuites>` element corresponds to the entire test program.
+*   `<testsuite>` elements correspond to GoogleTest test suites.
+*   `<testcase>` elements correspond to GoogleTest test functions.
+
+For instance, the following program
+
+```c++
+TEST(MathTest, Addition) { ... }
+TEST(MathTest, Subtraction) { ... }
+TEST(LogicTest, NonContradiction) { ... }
+```
+
+could generate this report:
+
+```xml
+<?xml version="1.0" encoding="UTF-8"?>
+<testsuites tests="3" failures="1" errors="0" time="0.035" timestamp="2011-10-31T18:52:42" name="AllTests">
+  <testsuite name="MathTest" tests="2" failures="1" errors="0" time="0.015">
+    <testcase name="Addition" file="test.cpp" line="1" status="run" time="0.007" classname="">
+      <failure message="Value of: add(1, 1)&#x0A;  Actual: 3&#x0A;Expected: 2" type=""/>
+      <failure message="Value of: add(1, -1)&#x0A;  Actual: 1&#x0A;Expected: 0" type=""/>
+    </testcase>
+    <testcase name="Subtraction" file="test.cpp" line="2" status="run" time="0.005" classname=""/>
+  </testsuite>
+  <testsuite name="LogicTest" tests="1" failures="0" errors="0" time="0.005">
+    <testcase name="NonContradiction" file="test.cpp" line="3" status="run" time="0.005" classname=""/>
+  </testsuite>
+</testsuites>
+```
+
+Things to note:
+
+*   The `tests` attribute of a `<testsuites>` or `<testsuite>` element tells how
+    many test functions the GoogleTest program or test suite contains, while the
+    `failures` attribute tells how many of them failed.
+
+*   The `time` attribute expresses the duration of the test, test suite, or
+    entire test program in seconds.
+
+*   The `timestamp` attribute records the local date and time of the test
+    execution.
+
+*   The `file` and `line` attributes record the source file location, where the
+    test was defined.
+
+*   Each `<failure>` element corresponds to a single failed GoogleTest
+    assertion.
+
+#### Generating a JSON Report
+
+GoogleTest can also emit a JSON report as an alternative format to XML. To
+generate the JSON report, set the `GTEST_OUTPUT` environment variable or the
+`--gtest_output` flag to the string `"json:path_to_output_file"`, which will
+create the file at the given location. 
You can also just use the string +`"json"`, in which case the output can be found in the `test_detail.json` file +in the current directory. + +The report format conforms to the following JSON Schema: + +```json +{ + "$schema": "http://json-schema.org/schema#", + "type": "object", + "definitions": { + "TestCase": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "tests": { "type": "integer" }, + "failures": { "type": "integer" }, + "disabled": { "type": "integer" }, + "time": { "type": "string" }, + "testsuite": { + "type": "array", + "items": { + "$ref": "#/definitions/TestInfo" + } + } + } + }, + "TestInfo": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "file": { "type": "string" }, + "line": { "type": "integer" }, + "status": { + "type": "string", + "enum": ["RUN", "NOTRUN"] + }, + "time": { "type": "string" }, + "classname": { "type": "string" }, + "failures": { + "type": "array", + "items": { + "$ref": "#/definitions/Failure" + } + } + } + }, + "Failure": { + "type": "object", + "properties": { + "failures": { "type": "string" }, + "type": { "type": "string" } + } + } + }, + "properties": { + "tests": { "type": "integer" }, + "failures": { "type": "integer" }, + "disabled": { "type": "integer" }, + "errors": { "type": "integer" }, + "timestamp": { + "type": "string", + "format": "date-time" + }, + "time": { "type": "string" }, + "name": { "type": "string" }, + "testsuites": { + "type": "array", + "items": { + "$ref": "#/definitions/TestCase" + } + } + } +} +``` + +The report uses the format that conforms to the following Proto3 using the +[JSON encoding](https://developers.google.com/protocol-buffers/docs/proto3#json): + +```proto +syntax = "proto3"; + +package googletest; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/duration.proto"; + +message UnitTest { + int32 tests = 1; + int32 failures = 2; + int32 disabled = 3; + int32 errors = 4; + google.protobuf.Timestamp timestamp = 5; + 
google.protobuf.Duration time = 6; + string name = 7; + repeated TestCase testsuites = 8; +} + +message TestCase { + string name = 1; + int32 tests = 2; + int32 failures = 3; + int32 disabled = 4; + int32 errors = 5; + google.protobuf.Duration time = 6; + repeated TestInfo testsuite = 7; +} + +message TestInfo { + string name = 1; + string file = 6; + int32 line = 7; + enum Status { + RUN = 0; + NOTRUN = 1; + } + Status status = 2; + google.protobuf.Duration time = 3; + string classname = 4; + message Failure { + string failures = 1; + string type = 2; + } + repeated Failure failures = 5; +} +``` + +For instance, the following program + +```c++ +TEST(MathTest, Addition) { ... } +TEST(MathTest, Subtraction) { ... } +TEST(LogicTest, NonContradiction) { ... } +``` + +could generate this report: + +```json +{ + "tests": 3, + "failures": 1, + "errors": 0, + "time": "0.035s", + "timestamp": "2011-10-31T18:52:42Z", + "name": "AllTests", + "testsuites": [ + { + "name": "MathTest", + "tests": 2, + "failures": 1, + "errors": 0, + "time": "0.015s", + "testsuite": [ + { + "name": "Addition", + "file": "test.cpp", + "line": 1, + "status": "RUN", + "time": "0.007s", + "classname": "", + "failures": [ + { + "message": "Value of: add(1, 1)\n Actual: 3\nExpected: 2", + "type": "" + }, + { + "message": "Value of: add(1, -1)\n Actual: 1\nExpected: 0", + "type": "" + } + ] + }, + { + "name": "Subtraction", + "file": "test.cpp", + "line": 2, + "status": "RUN", + "time": "0.005s", + "classname": "" + } + ] + }, + { + "name": "LogicTest", + "tests": 1, + "failures": 0, + "errors": 0, + "time": "0.005s", + "testsuite": [ + { + "name": "NonContradiction", + "file": "test.cpp", + "line": 3, + "status": "RUN", + "time": "0.005s", + "classname": "" + } + ] + } + ] +} +``` + +{: .callout .important} +IMPORTANT: The exact format of the JSON document is subject to change. 
+ +### Controlling How Failures Are Reported + +#### Detecting Test Premature Exit + +Google Test implements the _premature-exit-file_ protocol for test runners to +catch any kind of unexpected exits of test programs. Upon start, Google Test +creates the file which will be automatically deleted after all work has been +finished. Then, the test runner can check if this file exists. In case the file +remains undeleted, the inspected test has exited prematurely. + +This feature is enabled only if the `TEST_PREMATURE_EXIT_FILE` environment +variable has been set. + +#### Turning Assertion Failures into Break-Points + +When running test programs under a debugger, it's very convenient if the +debugger can catch an assertion failure and automatically drop into interactive +mode. GoogleTest's *break-on-failure* mode supports this behavior. + +To enable it, set the `GTEST_BREAK_ON_FAILURE` environment variable to a value +other than `0`. Alternatively, you can use the `--gtest_break_on_failure` +command line flag. + +#### Disabling Catching Test-Thrown Exceptions + +GoogleTest can be used either with or without exceptions enabled. If a test +throws a C++ exception or (on Windows) a structured exception (SEH), by default +GoogleTest catches it, reports it as a test failure, and continues with the next +test method. This maximizes the coverage of a test run. Also, on Windows an +uncaught exception will cause a pop-up window, so catching the exceptions allows +you to run the tests automatically. + +When debugging the test failures, however, you may instead want the exceptions +to be handled by the debugger, such that you can examine the call stack when an +exception is thrown. To achieve that, set the `GTEST_CATCH_EXCEPTIONS` +environment variable to `0`, or use the `--gtest_catch_exceptions=0` flag when +running the tests. 
+ +### Sanitizer Integration + +The +[Undefined Behavior Sanitizer](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html), +[Address Sanitizer](https://github.com/google/sanitizers/wiki/AddressSanitizer), +and +[Thread Sanitizer](https://github.com/google/sanitizers/wiki/ThreadSanitizerCppManual) +all provide weak functions that you can override to trigger explicit failures +when they detect sanitizer errors, such as creating a reference from `nullptr`. +To override these functions, place definitions for them in a source file that +you compile as part of your main binary: + +``` +extern "C" { +void __ubsan_on_report() { + FAIL() << "Encountered an undefined behavior sanitizer error"; +} +void __asan_on_error() { + FAIL() << "Encountered an address sanitizer error"; +} +void __tsan_on_report() { + FAIL() << "Encountered a thread sanitizer error"; +} +} // extern "C" +``` + +After compiling your project with one of the sanitizers enabled, if a particular +test triggers a sanitizer error, GoogleTest will report that it failed. diff --git a/third_party/googletest/docs/assets/css/style.scss b/third_party/googletest/docs/assets/css/style.scss new file mode 100644 index 0000000..bb30f41 --- /dev/null +++ b/third_party/googletest/docs/assets/css/style.scss @@ -0,0 +1,5 @@ +--- +--- + +@import "jekyll-theme-primer"; +@import "main"; diff --git a/third_party/googletest/docs/community_created_documentation.md b/third_party/googletest/docs/community_created_documentation.md new file mode 100644 index 0000000..4569075 --- /dev/null +++ b/third_party/googletest/docs/community_created_documentation.md @@ -0,0 +1,7 @@ +# Community-Created Documentation + +The following is a list, in no particular order, of links to documentation +created by the Googletest community. 
+ +* [Googlemock Insights](https://github.com/ElectricRCAircraftGuy/eRCaGuy_dotfiles/blob/master/googletest/insights.md), + by [ElectricRCAircraftGuy](https://github.com/ElectricRCAircraftGuy) diff --git a/third_party/googletest/docs/faq.md b/third_party/googletest/docs/faq.md new file mode 100644 index 0000000..1928097 --- /dev/null +++ b/third_party/googletest/docs/faq.md @@ -0,0 +1,692 @@ +# GoogleTest FAQ + +## Why should test suite names and test names not contain underscore? + +{: .callout .note} +Note: GoogleTest reserves underscore (`_`) for special purpose keywords, such as +[the `DISABLED_` prefix](advanced.md#temporarily-disabling-tests), in addition +to the following rationale. + +Underscore (`_`) is special, as C++ reserves the following to be used by the +compiler and the standard library: + +1. any identifier that starts with an `_` followed by an upper-case letter, and +2. any identifier that contains two consecutive underscores (i.e. `__`) + *anywhere* in its name. + +User code is *prohibited* from using such identifiers. + +Now let's look at what this means for `TEST` and `TEST_F`. + +Currently `TEST(TestSuiteName, TestName)` generates a class named +`TestSuiteName_TestName_Test`. What happens if `TestSuiteName` or `TestName` +contains `_`? + +1. If `TestSuiteName` starts with an `_` followed by an upper-case letter (say, + `_Foo`), we end up with `_Foo_TestName_Test`, which is reserved and thus + invalid. +2. If `TestSuiteName` ends with an `_` (say, `Foo_`), we get + `Foo__TestName_Test`, which is invalid. +3. If `TestName` starts with an `_` (say, `_Bar`), we get + `TestSuiteName__Bar_Test`, which is invalid. +4. If `TestName` ends with an `_` (say, `Bar_`), we get + `TestSuiteName_Bar__Test`, which is invalid. + +So clearly `TestSuiteName` and `TestName` cannot start or end with `_` +(Actually, `TestSuiteName` can start with `_` -- as long as the `_` isn't +followed by an upper-case letter. But that's getting complicated. 
So for +simplicity we just say that it cannot start with `_`.). + +It may seem fine for `TestSuiteName` and `TestName` to contain `_` in the +middle. However, consider this: + +```c++ +TEST(Time, Flies_Like_An_Arrow) { ... } +TEST(Time_Flies, Like_An_Arrow) { ... } +``` + +Now, the two `TEST`s will both generate the same class +(`Time_Flies_Like_An_Arrow_Test`). That's not good. + +So for simplicity, we just ask the users to avoid `_` in `TestSuiteName` and +`TestName`. The rule is more constraining than necessary, but it's simple and +easy to remember. It also gives GoogleTest some wiggle room in case its +implementation needs to change in the future. + +If you violate the rule, there may not be immediate consequences, but your test +may (just may) break with a new compiler (or a new version of the compiler you +are using) or with a new version of GoogleTest. Therefore it's best to follow +the rule. + +## Why does GoogleTest support `EXPECT_EQ(NULL, ptr)` and `ASSERT_EQ(NULL, ptr)` but not `EXPECT_NE(NULL, ptr)` and `ASSERT_NE(NULL, ptr)`? + +First of all, you can use `nullptr` with each of these macros, e.g. +`EXPECT_EQ(ptr, nullptr)`, `EXPECT_NE(ptr, nullptr)`, `ASSERT_EQ(ptr, nullptr)`, +`ASSERT_NE(ptr, nullptr)`. This is the preferred syntax in the style guide +because `nullptr` does not have the type problems that `NULL` does. + +Due to some peculiarity of C++, it requires some non-trivial template meta +programming tricks to support using `NULL` as an argument of the `EXPECT_XX()` +and `ASSERT_XX()` macros. Therefore we only do it where it's most needed +(otherwise we make the implementation of GoogleTest harder to maintain and more +error-prone than necessary). + +Historically, the `EXPECT_EQ()` macro took the *expected* value as its first +argument and the *actual* value as the second, though this argument order is now +discouraged. 
It was reasonable that someone wanted +to write `EXPECT_EQ(NULL, some_expression)`, and this indeed was requested +several times. Therefore we implemented it. + +The need for `EXPECT_NE(NULL, ptr)` wasn't nearly as strong. When the assertion +fails, you already know that `ptr` must be `NULL`, so it doesn't add any +information to print `ptr` in this case. That means `EXPECT_TRUE(ptr != NULL)` +works just as well. + +If we were to support `EXPECT_NE(NULL, ptr)`, for consistency we'd have to +support `EXPECT_NE(ptr, NULL)` as well. This means using the template meta +programming tricks twice in the implementation, making it even harder to +understand and maintain. We believe the benefit doesn't justify the cost. + +Finally, with the growth of the gMock matcher library, we are encouraging people +to use the unified `EXPECT_THAT(value, matcher)` syntax more often in tests. One +significant advantage of the matcher approach is that matchers can be easily +combined to form new matchers, while the `EXPECT_NE`, etc, macros cannot be +easily combined. Therefore we want to invest more in the matchers than in the +`EXPECT_XX()` macros. + +## I need to test that different implementations of an interface satisfy some common requirements. Should I use typed tests or value-parameterized tests? + +For testing various implementations of the same interface, either typed tests or +value-parameterized tests can get it done. It's really up to you the user to +decide which is more convenient for you, depending on your particular case. Some +rough guidelines: + +* Typed tests can be easier to write if instances of the different + implementations can be created the same way, modulo the type. For example, + if all these implementations have a public default constructor (such that + you can write `new TypeParam`), or if their factory functions have the same + form (e.g. `CreateInstance()`). 
+* Value-parameterized tests can be easier to write if you need different code
+    patterns to create different implementations' instances, e.g. `new Foo` vs
+    `new Bar(5)`. To accommodate for the differences, you can write factory
+    function wrappers and pass these function pointers to the tests as their
+    parameters.
+* When a typed test fails, the default output includes the name of the type,
+    which can help you quickly identify which implementation is wrong.
+    Value-parameterized tests only show the number of the failed iteration by
+    default. You will need to define a function that returns the iteration name
+    and pass it as the third parameter to INSTANTIATE_TEST_SUITE_P to have more
+    useful output.
+* When using typed tests, you need to make sure you are testing against the
+    interface type, not the concrete types (in other words, you want to make
+    sure `implicit_cast<MyInterface*>(my_concrete_impl)` works, not just that
+    `my_concrete_impl` works). It's less likely to make mistakes in this area
+    when using value-parameterized tests.
+
+I hope I didn't confuse you more. :-) If you don't mind, I'd suggest you to give
+both approaches a try. Practice is a much better way to grasp the subtle
+differences between the two tools. Once you have some concrete experience, you
+can much more easily decide which one to use the next time.
+
+## I got some run-time errors about invalid proto descriptors when using `ProtocolMessageEquals`. Help!
+
+{: .callout .note}
+**Note:** `ProtocolMessageEquals` and `ProtocolMessageEquiv` are *deprecated*
+now. Please use `EqualsProto`, etc instead.
+
+`ProtocolMessageEquals` and `ProtocolMessageEquiv` were redefined recently and
+are now less tolerant of invalid protocol buffer definitions. In particular, if
+you have a `foo.proto` that doesn't fully qualify the type of a protocol message
+it references (e.g. `message<Bar>` where it should be `message<blah.Bar>`), you
+will now get run-time errors like:
+
+```
+... descriptor.cc:...] 
Invalid proto descriptor for file "path/to/foo.proto":
+... descriptor.cc:...] blah.MyMessage.my_field: ".Bar" is not defined.
+```
+
+If you see this, your `.proto` file is broken and needs to be fixed by making
+the types fully qualified. The new definition of `ProtocolMessageEquals` and
+`ProtocolMessageEquiv` just happen to reveal your bug.
+
+## My death test modifies some state, but the change seems lost after the death test finishes. Why?
+
+Death tests (`EXPECT_DEATH`, etc) are executed in a sub-process s.t. the
+expected crash won't kill the test program (i.e. the parent process). As a
+result, any in-memory side effects they incur are observable in their respective
+sub-processes, but not in the parent process. You can think of them as running
+in a parallel universe, more or less.
+
+In particular, if you use mocking and the death test statement invokes some mock
+methods, the parent process will think the calls have never occurred. Therefore,
+you may want to move your `EXPECT_CALL` statements inside the `EXPECT_DEATH`
+macro.
+
+## EXPECT_EQ(htonl(blah), blah_blah) generates weird compiler errors in opt mode. Is this a GoogleTest bug?
+
+Actually, the bug is in `htonl()`.
+
+According to `'man htonl'`, `htonl()` is a *function*, which means it's valid to
+use `htonl` as a function pointer. However, in opt mode `htonl()` is defined as
+a *macro*, which breaks this usage.
+
+Worse, the macro definition of `htonl()` uses a `gcc` extension and is *not*
+standard C++. That hacky implementation has some ad hoc limitations. In
+particular, it prevents you from writing `Foo<sizeof(int)>()`, where `Foo`
+is a template that has an integral argument.
+
+The implementation of `EXPECT_EQ(a, b)` uses `sizeof(... a ...)` inside a
+template argument, and thus doesn't compile in opt mode when `a` contains a call
+to `htonl()`. It is difficult to make `EXPECT_EQ` bypass the `htonl()` bug, as
+the solution must work with different compilers on various platforms. 
+ +## The compiler complains about "undefined references" to some static const member variables, but I did define them in the class body. What's wrong? + +If your class has a static data member: + +```c++ +// foo.h +class Foo { + ... + static const int kBar = 100; +}; +``` + +You also need to define it *outside* of the class body in `foo.cc`: + +```c++ +const int Foo::kBar; // No initializer here. +``` + +Otherwise your code is **invalid C++**, and may break in unexpected ways. In +particular, using it in GoogleTest comparison assertions (`EXPECT_EQ`, etc) will +generate an "undefined reference" linker error. The fact that "it used to work" +doesn't mean it's valid. It just means that you were lucky. :-) + +If the declaration of the static data member is `constexpr` then it is +implicitly an `inline` definition, and a separate definition in `foo.cc` is not +needed: + +```c++ +// foo.h +class Foo { + ... + static constexpr int kBar = 100; // Defines kBar, no need to do it in foo.cc. +}; +``` + +## Can I derive a test fixture from another? + +Yes. + +Each test fixture has a corresponding and same named test suite. This means only +one test suite can use a particular fixture. Sometimes, however, multiple test +cases may want to use the same or slightly different fixtures. For example, you +may want to make sure that all of a GUI library's test suites don't leak +important system resources like fonts and brushes. + +In GoogleTest, you share a fixture among test suites by putting the shared logic +in a base test fixture, then deriving from that base a separate fixture for each +test suite that wants to use this common logic. You then use `TEST_F()` to write +tests using each derived fixture. + +Typically, your code looks like this: + +```c++ +// Defines a base test fixture. +class BaseTest : public ::testing::Test { + protected: + ... +}; + +// Derives a fixture FooTest from BaseTest. 
+class FooTest : public BaseTest { + protected: + void SetUp() override { + BaseTest::SetUp(); // Sets up the base fixture first. + ... additional set-up work ... + } + + void TearDown() override { + ... clean-up work for FooTest ... + BaseTest::TearDown(); // Remember to tear down the base fixture + // after cleaning up FooTest! + } + + ... functions and variables for FooTest ... +}; + +// Tests that use the fixture FooTest. +TEST_F(FooTest, Bar) { ... } +TEST_F(FooTest, Baz) { ... } + +... additional fixtures derived from BaseTest ... +``` + +If necessary, you can continue to derive test fixtures from a derived fixture. +GoogleTest has no limit on how deep the hierarchy can be. + +For a complete example using derived test fixtures, see +[sample5_unittest.cc](https://github.com/google/googletest/blob/main/googletest/samples/sample5_unittest.cc). + +## My compiler complains "void value not ignored as it ought to be." What does this mean? + +You're probably using an `ASSERT_*()` in a function that doesn't return `void`. +`ASSERT_*()` can only be used in `void` functions, due to exceptions being +disabled by our build system. Please see more details +[here](advanced.md#assertion-placement). + +## My death test hangs (or seg-faults). How do I fix it? + +In GoogleTest, death tests are run in a child process and the way they work is +delicate. To write death tests you really need to understand how they work—see +the details at [Death Assertions](reference/assertions.md#death) in the +Assertions Reference. + +In particular, death tests don't like having multiple threads in the parent +process. So the first thing you can try is to eliminate creating threads outside +of `EXPECT_DEATH()`. For example, you may want to use mocks or fake objects +instead of real ones in your tests. + +Sometimes this is impossible as some library you must use may be creating +threads before `main()` is even reached. 
In this case, you can try to minimize +the chance of conflicts by either moving as many activities as possible inside +`EXPECT_DEATH()` (in the extreme case, you want to move everything inside), or +leaving as few things as possible in it. Also, you can try to set the death test +style to `"threadsafe"`, which is safer but slower, and see if it helps. + +If you go with thread-safe death tests, remember that they rerun the test +program from the beginning in the child process. Therefore make sure your +program can run side-by-side with itself and is deterministic. + +In the end, this boils down to good concurrent programming. You have to make +sure that there are no race conditions or deadlocks in your program. No silver +bullet - sorry! + +## Should I use the constructor/destructor of the test fixture or SetUp()/TearDown()? {#CtorVsSetUp} + +The first thing to remember is that GoogleTest does **not** reuse the same test +fixture object across multiple tests. For each `TEST_F`, GoogleTest will create +a **fresh** test fixture object, immediately call `SetUp()`, run the test body, +call `TearDown()`, and then delete the test fixture object. + +When you need to write per-test set-up and tear-down logic, you have the choice +between using the test fixture constructor/destructor or `SetUp()/TearDown()`. +The former is usually preferred, as it has the following benefits: + +* By initializing a member variable in the constructor, we have the option to + make it `const`, which helps prevent accidental changes to its value and + makes the tests more obviously correct. +* In case we need to subclass the test fixture class, the subclass' + constructor is guaranteed to call the base class' constructor *first*, and + the subclass' destructor is guaranteed to call the base class' destructor + *afterward*. With `SetUp()/TearDown()`, a subclass may make the mistake of + forgetting to call the base class' `SetUp()/TearDown()` or call them at the + wrong time. 
+ +You may still want to use `SetUp()/TearDown()` in the following cases: + +* C++ does not allow virtual function calls in constructors and destructors. + You can call a method declared as virtual, but it will not use dynamic + dispatch. It will use the definition from the class the constructor of which + is currently executing. This is because calling a virtual method before the + derived class constructor has a chance to run is very dangerous - the + virtual method might operate on uninitialized data. Therefore, if you need + to call a method that will be overridden in a derived class, you have to use + `SetUp()/TearDown()`. +* In the body of a constructor (or destructor), it's not possible to use the + `ASSERT_xx` macros. Therefore, if the set-up operation could cause a fatal + test failure that should prevent the test from running, it's necessary to + use `abort` and abort the whole test + executable, or to use `SetUp()` instead of a constructor. +* If the tear-down operation could throw an exception, you must use + `TearDown()` as opposed to the destructor, as throwing in a destructor leads + to undefined behavior and usually will kill your program right away. Note + that many standard libraries (like STL) may throw when exceptions are + enabled in the compiler. Therefore you should prefer `TearDown()` if you + want to write portable tests that work with or without exceptions. +* The GoogleTest team is considering making the assertion macros throw on + platforms where exceptions are enabled (e.g. Windows, Mac OS, and Linux + client-side), which will eliminate the need for the user to propagate + failures from a subroutine to its caller. Therefore, you shouldn't use + GoogleTest assertions in a destructor if your code could run on such a + platform. + +## The compiler complains "no matching function to call" when I use ASSERT_PRED*. How do I fix it? + +See details for [`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) in the +Assertions Reference. 
+ +## My compiler complains about "ignoring return value" when I call RUN_ALL_TESTS(). Why? + +Some people had been ignoring the return value of `RUN_ALL_TESTS()`. That is, +instead of + +```c++ + return RUN_ALL_TESTS(); +``` + +they write + +```c++ + RUN_ALL_TESTS(); +``` + +This is **wrong and dangerous**. The testing services needs to see the return +value of `RUN_ALL_TESTS()` in order to determine if a test has passed. If your +`main()` function ignores it, your test will be considered successful even if it +has a GoogleTest assertion failure. Very bad. + +We have decided to fix this (thanks to Michael Chastain for the idea). Now, your +code will no longer be able to ignore `RUN_ALL_TESTS()` when compiled with +`gcc`. If you do so, you'll get a compiler error. + +If you see the compiler complaining about you ignoring the return value of +`RUN_ALL_TESTS()`, the fix is simple: just make sure its value is used as the +return value of `main()`. + +But how could we introduce a change that breaks existing tests? Well, in this +case, the code was already broken in the first place, so we didn't break it. :-) + +## My compiler complains that a constructor (or destructor) cannot return a value. What's going on? + +Due to a peculiarity of C++, in order to support the syntax for streaming +messages to an `ASSERT_*`, e.g. + +```c++ + ASSERT_EQ(1, Foo()) << "blah blah" << foo; +``` + +we had to give up using `ASSERT*` and `FAIL*` (but not `EXPECT*` and +`ADD_FAILURE*`) in constructors and destructors. The workaround is to move the +content of your constructor/destructor to a private void member function, or +switch to `EXPECT_*()` if that works. This +[section](advanced.md#assertion-placement) in the user's guide explains it. + +## My SetUp() function is not called. Why? + +C++ is case-sensitive. Did you spell it as `Setup()`? + +Similarly, sometimes people spell `SetUpTestSuite()` as `SetupTestSuite()` and +wonder why it's never called. 
+ +## I have several test suites which share the same test fixture logic, do I have to define a new test fixture class for each of them? This seems pretty tedious. + +You don't have to. Instead of + +```c++ +class FooTest : public BaseTest {}; + +TEST_F(FooTest, Abc) { ... } +TEST_F(FooTest, Def) { ... } + +class BarTest : public BaseTest {}; + +TEST_F(BarTest, Abc) { ... } +TEST_F(BarTest, Def) { ... } +``` + +you can simply `typedef` the test fixtures: + +```c++ +typedef BaseTest FooTest; + +TEST_F(FooTest, Abc) { ... } +TEST_F(FooTest, Def) { ... } + +typedef BaseTest BarTest; + +TEST_F(BarTest, Abc) { ... } +TEST_F(BarTest, Def) { ... } +``` + +## GoogleTest output is buried in a whole bunch of LOG messages. What do I do? + +The GoogleTest output is meant to be a concise and human-friendly report. If +your test generates textual output itself, it will mix with the GoogleTest +output, making it hard to read. However, there is an easy solution to this +problem. + +Since `LOG` messages go to stderr, we decided to let GoogleTest output go to +stdout. This way, you can easily separate the two using redirection. For +example: + +```shell +$ ./my_test > gtest_output.txt +``` + +## Why should I prefer test fixtures over global variables? + +There are several good reasons: + +1. It's likely your test needs to change the states of its global variables. + This makes it difficult to keep side effects from escaping one test and + contaminating others, making debugging difficult. By using fixtures, each + test has a fresh set of variables that's different (but with the same + names). Thus, tests are kept independent of each other. +2. Global variables pollute the global namespace. +3. Test fixtures can be reused via subclassing, which cannot be done easily + with global variables. This is useful if many test suites have something in + common. + +## What can the statement argument in ASSERT_DEATH() be? 
+
+`ASSERT_DEATH(statement, matcher)` (or any death assertion macro) can be used
+wherever *`statement`* is valid. So basically *`statement`* can be any C++
+statement that makes sense in the current context. In particular, it can
+reference global and/or local variables, and can be:
+
+* a simple function call (often the case),
+* a complex expression, or
+* a compound statement.
+
+Some examples are shown here:
+
+```c++
+// A death test can be a simple function call.
+TEST(MyDeathTest, FunctionCall) {
+  ASSERT_DEATH(Xyz(5), "Xyz failed");
+}
+
+// Or a complex expression that references variables and functions.
+TEST(MyDeathTest, ComplexExpression) {
+  const bool c = Condition();
+  ASSERT_DEATH((c ? Func1(0) : object2.Method("test")),
+               "(Func1|Method) failed");
+}
+
+// Death assertions can be used anywhere in a function. In
+// particular, they can be inside a loop.
+TEST(MyDeathTest, InsideLoop) {
+  // Verifies that Foo(0), Foo(1), ..., and Foo(4) all die.
+  for (int i = 0; i < 5; i++) {
+    EXPECT_DEATH_M(Foo(i), "Foo has \\d+ errors",
+                   ::testing::Message() << "where i is " << i);
+  }
+}
+
+// A death assertion can contain a compound statement.
+TEST(MyDeathTest, CompoundStatement) {
+  // Verifies that at least one of Bar(0), Bar(1), ..., and
+  // Bar(4) dies.
+  ASSERT_DEATH({
+    for (int i = 0; i < 5; i++) {
+      Bar(i);
+    }
+  },
+  "Bar has \\d+ errors");
+}
+```
+
+## I have a fixture class `FooTest`, but `TEST_F(FooTest, Bar)` gives me error ``"no matching function for call to `FooTest::FooTest()'"``. Why?
+
+GoogleTest needs to be able to create objects of your test fixture class, so it
+must have a default constructor. Normally the compiler will define one for you.
+However, there are cases where you have to define your own:
+
+* If you explicitly declare a non-default constructor for class `FooTest`
+  (`DISALLOW_EVIL_CONSTRUCTORS()` does this), then you need to define a
+  default constructor, even if it would be empty. 
+* If `FooTest` has a const non-static data member, then you have to define the + default constructor *and* initialize the const member in the initializer + list of the constructor. (Early versions of `gcc` doesn't force you to + initialize the const member. It's a bug that has been fixed in `gcc 4`.) + +## Why does ASSERT_DEATH complain about previous threads that were already joined? + +With the Linux pthread library, there is no turning back once you cross the line +from a single thread to multiple threads. The first time you create a thread, a +manager thread is created in addition, so you get 3, not 2, threads. Later when +the thread you create joins the main thread, the thread count decrements by 1, +but the manager thread will never be killed, so you still have 2 threads, which +means you cannot safely run a death test. + +The new NPTL thread library doesn't suffer from this problem, as it doesn't +create a manager thread. However, if you don't control which machine your test +runs on, you shouldn't depend on this. + +## Why does GoogleTest require the entire test suite, instead of individual tests, to be named *DeathTest when it uses ASSERT_DEATH? + +GoogleTest does not interleave tests from different test suites. That is, it +runs all tests in one test suite first, and then runs all tests in the next test +suite, and so on. GoogleTest does this because it needs to set up a test suite +before the first test in it is run, and tear it down afterwards. Splitting up +the test case would require multiple set-up and tear-down processes, which is +inefficient and makes the semantics unclean. + +If we were to determine the order of tests based on test name instead of test +case name, then we would have a problem with the following situation: + +```c++ +TEST_F(FooTest, AbcDeathTest) { ... } +TEST_F(FooTest, Uvw) { ... } + +TEST_F(BarTest, DefDeathTest) { ... } +TEST_F(BarTest, Xyz) { ... 
} +``` + +Since `FooTest.AbcDeathTest` needs to run before `BarTest.Xyz`, and we don't +interleave tests from different test suites, we need to run all tests in the +`FooTest` case before running any test in the `BarTest` case. This contradicts +with the requirement to run `BarTest.DefDeathTest` before `FooTest.Uvw`. + +## But I don't like calling my entire test suite \*DeathTest when it contains both death tests and non-death tests. What do I do? + +You don't have to, but if you like, you may split up the test suite into +`FooTest` and `FooDeathTest`, where the names make it clear that they are +related: + +```c++ +class FooTest : public ::testing::Test { ... }; + +TEST_F(FooTest, Abc) { ... } +TEST_F(FooTest, Def) { ... } + +using FooDeathTest = FooTest; + +TEST_F(FooDeathTest, Uvw) { ... EXPECT_DEATH(...) ... } +TEST_F(FooDeathTest, Xyz) { ... ASSERT_DEATH(...) ... } +``` + +## GoogleTest prints the LOG messages in a death test's child process only when the test fails. How can I see the LOG messages when the death test succeeds? + +Printing the LOG messages generated by the statement inside `EXPECT_DEATH()` +makes it harder to search for real problems in the parent's log. Therefore, +GoogleTest only prints them when the death test has failed. + +If you really need to see such LOG messages, a workaround is to temporarily +break the death test (e.g. by changing the regex pattern it is expected to +match). Admittedly, this is a hack. We'll consider a more permanent solution +after the fork-and-exec-style death tests are implemented. + +## The compiler complains about `no match for 'operator<<'` when I use an assertion. What gives? + +If you use a user-defined type `FooType` in an assertion, you must make sure +there is an `std::ostream& operator<<(std::ostream&, const FooType&)` function +defined such that we can print a value of `FooType`. 
+ +In addition, if `FooType` is declared in a name space, the `<<` operator also +needs to be defined in the *same* name space. See +[Tip of the Week #49](http://abseil.io/tips/49) for details. + +## How do I suppress the memory leak messages on Windows? + +Since the statically initialized GoogleTest singleton requires allocations on +the heap, the Visual C++ memory leak detector will report memory leaks at the +end of the program run. The easiest way to avoid this is to use the +`_CrtMemCheckpoint` and `_CrtMemDumpAllObjectsSince` calls to not report any +statically initialized heap objects. See MSDN for more details and additional +heap check/debug routines. + +## How can my code detect if it is running in a test? + +If you write code that sniffs whether it's running in a test and does different +things accordingly, you are leaking test-only logic into production code and +there is no easy way to ensure that the test-only code paths aren't run by +mistake in production. Such cleverness also leads to +[Heisenbugs](https://en.wikipedia.org/wiki/Heisenbug). Therefore we strongly +advise against the practice, and GoogleTest doesn't provide a way to do it. + +In general, the recommended way to cause the code to behave differently under +test is [Dependency Injection](http://en.wikipedia.org/wiki/Dependency_injection). You can inject +different functionality from the test and from the production code. Since your +production code doesn't link in the for-test logic at all (the +[`testonly`](http://docs.bazel.build/versions/master/be/common-definitions.html#common.testonly) attribute for BUILD targets helps to ensure +that), there is no danger in accidentally running it. + +However, if you *really*, *really*, *really* have no choice, and if you follow +the rule of ending your test program names with `_test`, you can use the +*horrible* hack of sniffing your executable name (`argv[0]` in `main()`) to know +whether the code is under test. 
+ +## How do I temporarily disable a test? + +If you have a broken test that you cannot fix right away, you can add the +`DISABLED_` prefix to its name. This will exclude it from execution. This is +better than commenting out the code or using `#if 0`, as disabled tests are +still compiled (and thus won't rot). + +To include disabled tests in test execution, just invoke the test program with +the `--gtest_also_run_disabled_tests` flag. + +## Is it OK if I have two separate `TEST(Foo, Bar)` test methods defined in different namespaces? + +Yes. + +The rule is **all test methods in the same test suite must use the same fixture +class.** This means that the following is **allowed** because both tests use the +same fixture class (`::testing::Test`). + +```c++ +namespace foo { +TEST(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace foo + +namespace bar { +TEST(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace bar +``` + +However, the following code is **not allowed** and will produce a runtime error +from GoogleTest because the test methods are using different test fixture +classes with the same test suite name. 
+
+```c++
+namespace foo {
+class CoolTest : public ::testing::Test {};  // Fixture foo::CoolTest
+TEST_F(CoolTest, DoSomething) {
+  SUCCEED();
+}
+}  // namespace foo
+
+namespace bar {
+class CoolTest : public ::testing::Test {};  // Fixture: bar::CoolTest
+TEST_F(CoolTest, DoSomething) {
+  SUCCEED();
+}
+}  // namespace bar
+```
diff --git a/third_party/googletest/docs/gmock_cheat_sheet.md b/third_party/googletest/docs/gmock_cheat_sheet.md
new file mode 100644
index 0000000..ddafaaa
--- /dev/null
+++ b/third_party/googletest/docs/gmock_cheat_sheet.md
@@ -0,0 +1,241 @@
+# gMock Cheat Sheet
+
+## Defining a Mock Class
+
+### Mocking a Normal Class {#MockClass}
+
+Given
+
+```cpp
+class Foo {
+ public:
+  virtual ~Foo();
+  virtual int GetSize() const = 0;
+  virtual string Describe(const char* name) = 0;
+  virtual string Describe(int type) = 0;
+  virtual bool Process(Bar elem, int count) = 0;
+};
+```
+
+(note that `~Foo()` **must** be virtual) we can define its mock as
+
+```cpp
+#include <gmock/gmock.h>
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(int, GetSize, (), (const, override));
+  MOCK_METHOD(string, Describe, (const char* name), (override));
+  MOCK_METHOD(string, Describe, (int type), (override));
+  MOCK_METHOD(bool, Process, (Bar elem, int count), (override));
+};
+```
+
+To create a "nice" mock, which ignores all uninteresting calls, a "naggy" mock,
+which warns on all uninteresting calls, or a "strict" mock, which treats them as
+failures:
+
+```cpp
+using ::testing::NiceMock;
+using ::testing::NaggyMock;
+using ::testing::StrictMock;
+
+NiceMock<MockFoo> nice_foo;      // The type is a subclass of MockFoo.
+NaggyMock<MockFoo> naggy_foo;    // The type is a subclass of MockFoo.
+StrictMock<MockFoo> strict_foo;  // The type is a subclass of MockFoo.
+```
+
+{: .callout .note}
+**Note:** A mock object is currently naggy by default. We may make it nice by
+default in the future.
+
+### Mocking a Class Template {#MockTemplate}
+
+Class templates can be mocked just like any class. 
+
+To mock
+
+```cpp
+template <typename Elem>
+class StackInterface {
+ public:
+  virtual ~StackInterface();
+  virtual int GetSize() const = 0;
+  virtual void Push(const Elem& x) = 0;
+};
+```
+
+(note that all member functions that are mocked, including `~StackInterface()`
+**must** be virtual).
+
+```cpp
+template <typename Elem>
+class MockStack : public StackInterface<Elem> {
+ public:
+  MOCK_METHOD(int, GetSize, (), (const, override));
+  MOCK_METHOD(void, Push, (const Elem& x), (override));
+};
+```
+
+### Specifying Calling Conventions for Mock Functions
+
+If your mock function doesn't use the default calling convention, you can
+specify it by adding `Calltype(convention)` to `MOCK_METHOD`'s 4th parameter.
+For example,
+
+```cpp
+  MOCK_METHOD(bool, Foo, (int n), (Calltype(STDMETHODCALLTYPE)));
+  MOCK_METHOD(int, Bar, (double x, double y),
+              (const, Calltype(STDMETHODCALLTYPE)));
+```
+
+where `STDMETHODCALLTYPE` is defined by `<objbase.h>` on Windows.
+
+## Using Mocks in Tests {#UsingMocks}
+
+The typical work flow is:
+
+1. Import the gMock names you need to use. All gMock symbols are in the
+   `testing` namespace unless they are macros or otherwise noted.
+2. Create the mock objects.
+3. Optionally, set the default actions of the mock objects.
+4. Set your expectations on the mock objects (How will they be called? What
+   will they do?).
+5. Exercise code that uses the mock objects; if necessary, check the result
+   using googletest assertions.
+6. When a mock object is destructed, gMock automatically verifies that all
+   expectations on it have been satisfied.
+
+Here's an example:
+
+```cpp
+using ::testing::Return;                          // #1
+
+TEST(BarTest, DoesThis) {
+  MockFoo foo;                                    // #2
+
+  ON_CALL(foo, GetSize())                         // #3
+      .WillByDefault(Return(1));
+  // ... other default actions ...
+
+  EXPECT_CALL(foo, Describe(5))                   // #4
+      .Times(3)
+      .WillRepeatedly(Return("Category 5"));
+  // ... other expectations ... 
+
+  EXPECT_EQ(MyProductionFunction(&foo), "good");  // #5
+}                                                 // #6
+```
+
+## Setting Default Actions {#OnCall}
+
+gMock has a **built-in default action** for any function that returns `void`,
+`bool`, a numeric value, or a pointer. In C++11, it will additionally return
+the default-constructed value, if one exists for the given type.
+
+To customize the default action for functions with return type `T`, use
+[`DefaultValue<T>`](reference/mocking.md#DefaultValue). For example:
+
+```cpp
+  // Sets the default action for return type std::unique_ptr<Buzz> to
+  // creating a new Buzz every time.
+  DefaultValue<std::unique_ptr<Buzz>>::SetFactory(
+      [] { return std::make_unique<Buzz>(AccessLevel::kInternal); });
+
+  // When this fires, the default action of MakeBuzz() will run, which
+  // will return a new Buzz object.
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello")).Times(AnyNumber());
+
+  auto buzz1 = mock_buzzer_.MakeBuzz("hello");
+  auto buzz2 = mock_buzzer_.MakeBuzz("hello");
+  EXPECT_NE(buzz1, nullptr);
+  EXPECT_NE(buzz2, nullptr);
+  EXPECT_NE(buzz1, buzz2);
+
+  // Resets the default action for return type std::unique_ptr<Buzz>,
+  // to avoid interfering with other tests.
+  DefaultValue<std::unique_ptr<Buzz>>::Clear();
+```
+
+To customize the default action for a particular method of a specific mock
+object, use [`ON_CALL`](reference/mocking.md#ON_CALL). `ON_CALL` has a similar
+syntax to `EXPECT_CALL`, but it is used for setting default behaviors when you
+do not require that the mock method is called. See
+[Knowing When to Expect](gmock_cook_book.md#UseOnCall) for a more detailed
+discussion.
+
+## Setting Expectations {#ExpectCall}
+
+See [`EXPECT_CALL`](reference/mocking.md#EXPECT_CALL) in the Mocking Reference.
+
+## Matchers {#MatcherList}
+
+See the [Matchers Reference](reference/matchers.md).
+
+## Actions {#ActionList}
+
+See the [Actions Reference](reference/actions.md).
+
+## Cardinalities {#CardinalityList}
+
+See the [`Times` clause](reference/mocking.md#EXPECT_CALL.Times) of
+`EXPECT_CALL` in the Mocking Reference. 
+ +## Expectation Order + +By default, expectations can be matched in *any* order. If some or all +expectations must be matched in a given order, you can use the +[`After` clause](reference/mocking.md#EXPECT_CALL.After) or +[`InSequence` clause](reference/mocking.md#EXPECT_CALL.InSequence) of +`EXPECT_CALL`, or use an [`InSequence` object](reference/mocking.md#InSequence). + +## Verifying and Resetting a Mock + +gMock will verify the expectations on a mock object when it is destructed, or +you can do it earlier: + +```cpp +using ::testing::Mock; +... +// Verifies and removes the expectations on mock_obj; +// returns true if and only if successful. +Mock::VerifyAndClearExpectations(&mock_obj); +... +// Verifies and removes the expectations on mock_obj; +// also removes the default actions set by ON_CALL(); +// returns true if and only if successful. +Mock::VerifyAndClear(&mock_obj); +``` + +Do not set new expectations after verifying and clearing a mock after its use. +Setting expectations after code that exercises the mock has undefined behavior. +See [Using Mocks in Tests](gmock_for_dummies.md#using-mocks-in-tests) for more +information. + +You can also tell gMock that a mock object can be leaked and doesn't need to be +verified: + +```cpp +Mock::AllowLeak(&mock_obj); +``` + +## Mock Classes + +gMock defines a convenient mock class template + +```cpp +class MockFunction { + public: + MOCK_METHOD(R, Call, (A1, ..., An)); +}; +``` + +See this [recipe](gmock_cook_book.md#UsingCheckPoints) for one application of +it. + +## Flags + +| Flag | Description | +| :----------------------------- | :---------------------------------------- | +| `--gmock_catch_leaked_mocks=0` | Don't report leaked mock objects as failures. | +| `--gmock_verbose=LEVEL` | Sets the default verbosity level (`info`, `warning`, or `error`) of Google Mock messages. 
| diff --git a/third_party/googletest/docs/gmock_cook_book.md b/third_party/googletest/docs/gmock_cook_book.md new file mode 100644 index 0000000..da10918 --- /dev/null +++ b/third_party/googletest/docs/gmock_cook_book.md @@ -0,0 +1,4344 @@ +# gMock Cookbook + +You can find recipes for using gMock here. If you haven't yet, please read +[the dummy guide](gmock_for_dummies.md) first to make sure you understand the +basics. + +{: .callout .note} +**Note:** gMock lives in the `testing` name space. For readability, it is +recommended to write `using ::testing::Foo;` once in your file before using the +name `Foo` defined by gMock. We omit such `using` statements in this section for +brevity, but you should do it in your own code. + +## Creating Mock Classes + +Mock classes are defined as normal classes, using the `MOCK_METHOD` macro to +generate mocked methods. The macro gets 3 or 4 parameters: + +```cpp +class MyMock { + public: + MOCK_METHOD(ReturnType, MethodName, (Args...)); + MOCK_METHOD(ReturnType, MethodName, (Args...), (Specs...)); +}; +``` + +The first 3 parameters are simply the method declaration, split into 3 parts. +The 4th parameter accepts a closed list of qualifiers, which affect the +generated method: + +* **`const`** - Makes the mocked method a `const` method. Required if + overriding a `const` method. +* **`override`** - Marks the method with `override`. Recommended if overriding + a `virtual` method. +* **`noexcept`** - Marks the method with `noexcept`. Required if overriding a + `noexcept` method. +* **`Calltype(...)`** - Sets the call type for the method (e.g. to + `STDMETHODCALLTYPE`), useful in Windows. +* **`ref(...)`** - Marks the method with the reference qualification + specified. Required if overriding a method that has reference + qualifications. Eg `ref(&)` or `ref(&&)`. + +### Dealing with unprotected commas + +Unprotected commas, i.e. 
commas which are not surrounded by parentheses, prevent +`MOCK_METHOD` from parsing its arguments correctly: + +{: .bad} +```cpp +class MockFoo { + public: + MOCK_METHOD(std::pair, GetPair, ()); // Won't compile! + MOCK_METHOD(bool, CheckMap, (std::map, bool)); // Won't compile! +}; +``` + +Solution 1 - wrap with parentheses: + +{: .good} +```cpp +class MockFoo { + public: + MOCK_METHOD((std::pair), GetPair, ()); + MOCK_METHOD(bool, CheckMap, ((std::map), bool)); +}; +``` + +Note that wrapping a return or argument type with parentheses is, in general, +invalid C++. `MOCK_METHOD` removes the parentheses. + +Solution 2 - define an alias: + +{: .good} +```cpp +class MockFoo { + public: + using BoolAndInt = std::pair; + MOCK_METHOD(BoolAndInt, GetPair, ()); + using MapIntDouble = std::map; + MOCK_METHOD(bool, CheckMap, (MapIntDouble, bool)); +}; +``` + +### Mocking Private or Protected Methods + +You must always put a mock method definition (`MOCK_METHOD`) in a `public:` +section of the mock class, regardless of the method being mocked being `public`, +`protected`, or `private` in the base class. This allows `ON_CALL` and +`EXPECT_CALL` to reference the mock function from outside of the mock class. +(Yes, C++ allows a subclass to change the access level of a virtual function in +the base class.) Example: + +```cpp +class Foo { + public: + ... + virtual bool Transform(Gadget* g) = 0; + + protected: + virtual void Resume(); + + private: + virtual int GetTimeOut(); +}; + +class MockFoo : public Foo { + public: + ... + MOCK_METHOD(bool, Transform, (Gadget* g), (override)); + + // The following must be in the public section, even though the + // methods are protected or private in the base class. + MOCK_METHOD(void, Resume, (), (override)); + MOCK_METHOD(int, GetTimeOut, (), (override)); +}; +``` + +### Mocking Overloaded Methods + +You can mock overloaded functions as usual. No special attention is required: + +```cpp +class Foo { + ... 
+
+  // Must be virtual as we'll inherit from Foo.
+  virtual ~Foo();
+
+  // Overloaded on the types and/or numbers of arguments.
+  virtual int Add(Element x);
+  virtual int Add(int times, Element x);
+
+  // Overloaded on the const-ness of this object.
+  virtual Bar& GetBar();
+  virtual const Bar& GetBar() const;
+};
+
+class MockFoo : public Foo {
+  ...
+  MOCK_METHOD(int, Add, (Element x), (override));
+  MOCK_METHOD(int, Add, (int times, Element x), (override));
+
+  MOCK_METHOD(Bar&, GetBar, (), (override));
+  MOCK_METHOD(const Bar&, GetBar, (), (const, override));
+};
+```
+
+{: .callout .note}
+**Note:** if you don't mock all versions of the overloaded method, the compiler
+will give you a warning about some methods in the base class being hidden. To
+fix that, use `using` to bring them in scope:
+
+```cpp
+class MockFoo : public Foo {
+  ...
+  using Foo::Add;
+  MOCK_METHOD(int, Add, (Element x), (override));
+  // We don't want to mock int Add(int times, Element x);
+  ...
+};
+```
+
+### Mocking Class Templates
+
+You can mock class templates just like any class.
+
+```cpp
+template <typename Elem>
+class StackInterface {
+  ...
+  // Must be virtual as we'll inherit from StackInterface.
+  virtual ~StackInterface();
+
+  virtual int GetSize() const = 0;
+  virtual void Push(const Elem& x) = 0;
+};
+
+template <typename Elem>
+class MockStack : public StackInterface<Elem> {
+  ...
+  MOCK_METHOD(int, GetSize, (), (const, override));
+  MOCK_METHOD(void, Push, (const Elem& x), (override));
+};
+```
+
+### Mocking Non-virtual Methods {#MockingNonVirtualMethods}
+
+gMock can mock non-virtual functions to be used in Hi-perf dependency injection.
+
+In this case, instead of sharing a common base class with the real class, your
+mock class will be *unrelated* to the real class, but contain methods with the
+same signatures. The syntax for mocking non-virtual methods is the *same* as
+mocking virtual methods (just don't add `override`):
+
+```cpp
+// A simple packet stream class. None of its members is virtual.
+class ConcretePacketStream { + public: + void AppendPacket(Packet* new_packet); + const Packet* GetPacket(size_t packet_number) const; + size_t NumberOfPackets() const; + ... +}; + +// A mock packet stream class. It inherits from no other, but defines +// GetPacket() and NumberOfPackets(). +class MockPacketStream { + public: + MOCK_METHOD(const Packet*, GetPacket, (size_t packet_number), (const)); + MOCK_METHOD(size_t, NumberOfPackets, (), (const)); + ... +}; +``` + +Note that the mock class doesn't define `AppendPacket()`, unlike the real class. +That's fine as long as the test doesn't need to call it. + +Next, you need a way to say that you want to use `ConcretePacketStream` in +production code, and use `MockPacketStream` in tests. Since the functions are +not virtual and the two classes are unrelated, you must specify your choice at +*compile time* (as opposed to run time). + +One way to do it is to templatize your code that needs to use a packet stream. +More specifically, you will give your code a template type argument for the type +of the packet stream. In production, you will instantiate your template with +`ConcretePacketStream` as the type argument. In tests, you will instantiate the +same template with `MockPacketStream`. For example, you may write: + +```cpp +template +void CreateConnection(PacketStream* stream) { ... } + +template +class PacketReader { + public: + void ReadPackets(PacketStream* stream, size_t packet_num); +}; +``` + +Then you can use `CreateConnection()` and +`PacketReader` in production code, and use +`CreateConnection()` and `PacketReader` in +tests. + +```cpp + MockPacketStream mock_stream; + EXPECT_CALL(mock_stream, ...)...; + .. set more expectations on mock_stream ... + PacketReader reader(&mock_stream); + ... exercise reader ... +``` + +### Mocking Free Functions + +It is not possible to directly mock a free function (i.e. a C-style function or +a static method). 
If you need to, you can rewrite your code to use an interface
+(abstract class).
+
+Instead of calling a free function (say, `OpenFile`) directly, introduce an
+interface for it and have a concrete subclass that calls the free function:
+
+```cpp
+class FileInterface {
+ public:
+  ...
+  virtual bool Open(const char* path, const char* mode) = 0;
+};
+
+class File : public FileInterface {
+ public:
+  ...
+  bool Open(const char* path, const char* mode) override {
+    return OpenFile(path, mode);
+  }
+};
+```
+
+Your code should talk to `FileInterface` to open a file. Now it's easy to mock
+out the function.
+
+This may seem like a lot of hassle, but in practice you often have multiple
+related functions that you can put in the same interface, so the per-function
+syntactic overhead will be much lower.
+
+If you are concerned about the performance overhead incurred by virtual
+functions, and profiling confirms your concern, you can combine this with the
+recipe for [mocking non-virtual methods](#MockingNonVirtualMethods).
+
+Alternatively, instead of introducing a new interface, you can rewrite your code
+to accept a std::function instead of the free function, and then use
+[MockFunction](#MockFunction) to mock the std::function.
+
+### Old-Style `MOCK_METHODn` Macros
+
+Before the generic `MOCK_METHOD` macro
+[was introduced in 2018](https://github.com/google/googletest/commit/c5f08bf91944ce1b19bcf414fa1760e69d20afc2),
+mocks were created using a family of macros collectively called `MOCK_METHODn`.
+These macros are still supported, though migration to the new `MOCK_METHOD` is
+recommended.
+
+The macros in the `MOCK_METHODn` family differ from `MOCK_METHOD`:
+
+*   The general structure is `MOCK_METHODn(MethodName, ReturnType(Args))`,
+    instead of `MOCK_METHOD(ReturnType, MethodName, (Args))`.
+*   The number `n` must equal the number of arguments.
+*   When mocking a const method, one must use `MOCK_CONST_METHODn`.
+* When mocking a class template, the macro name must be suffixed with `_T`. +* In order to specify the call type, the macro name must be suffixed with + `_WITH_CALLTYPE`, and the call type is the first macro argument. + +Old macros and their new equivalents: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Simple
OldMOCK_METHOD1(Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int))
Const Method
OldMOCK_CONST_METHOD1(Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int), (const))
Method in a Class Template
OldMOCK_METHOD1_T(Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int))
Const Method in a Class Template
OldMOCK_CONST_METHOD1_T(Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int), (const))
Method with Call Type
OldMOCK_METHOD1_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int), (Calltype(STDMETHODCALLTYPE)))
Const Method with Call Type
OldMOCK_CONST_METHOD1_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int), (const, Calltype(STDMETHODCALLTYPE)))
Method with Call Type in a Class Template
OldMOCK_METHOD1_T_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int), (Calltype(STDMETHODCALLTYPE)))
Const Method with Call Type in a Class Template
OldMOCK_CONST_METHOD1_T_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))
NewMOCK_METHOD(bool, Foo, (int), (const, Calltype(STDMETHODCALLTYPE)))
+ +### The Nice, the Strict, and the Naggy {#NiceStrictNaggy} + +If a mock method has no `EXPECT_CALL` spec but is called, we say that it's an +"uninteresting call", and the default action (which can be specified using +`ON_CALL()`) of the method will be taken. Currently, an uninteresting call will +also by default cause gMock to print a warning. + +However, sometimes you may want to ignore these uninteresting calls, and +sometimes you may want to treat them as errors. gMock lets you make the decision +on a per-mock-object basis. + +Suppose your test uses a mock class `MockFoo`: + +```cpp +TEST(...) { + MockFoo mock_foo; + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... +} +``` + +If a method of `mock_foo` other than `DoThis()` is called, you will get a +warning. However, if you rewrite your test to use `NiceMock` instead, +you can suppress the warning: + +```cpp +using ::testing::NiceMock; + +TEST(...) { + NiceMock mock_foo; + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... +} +``` + +`NiceMock` is a subclass of `MockFoo`, so it can be used wherever +`MockFoo` is accepted. + +It also works if `MockFoo`'s constructor takes some arguments, as +`NiceMock` "inherits" `MockFoo`'s constructors: + +```cpp +using ::testing::NiceMock; + +TEST(...) { + NiceMock mock_foo(5, "hi"); // Calls MockFoo(5, "hi"). + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... +} +``` + +The usage of `StrictMock` is similar, except that it makes all uninteresting +calls failures: + +```cpp +using ::testing::StrictMock; + +TEST(...) { + StrictMock mock_foo; + EXPECT_CALL(mock_foo, DoThis()); + ... code that uses mock_foo ... + + // The test will fail if a method of mock_foo other than DoThis() + // is called. 
+}
+```
+
+{: .callout .note}
+NOTE: `NiceMock` and `StrictMock` only affect *uninteresting* calls (calls of
+*methods* with no expectations); they do not affect *unexpected* calls (calls of
+methods with expectations, but they don't match). See
+[Understanding Uninteresting vs Unexpected Calls](#uninteresting-vs-unexpected).
+
+There are some caveats though (sadly they are side effects of C++'s
+limitations):
+
+1.  `NiceMock<MockFoo>` and `StrictMock<MockFoo>` only work for mock methods
+    defined using the `MOCK_METHOD` macro **directly** in the `MockFoo` class.
+    If a mock method is defined in a **base class** of `MockFoo`, the "nice" or
+    "strict" modifier may not affect it, depending on the compiler. In
+    particular, nesting `NiceMock` and `StrictMock` (e.g.
+    `NiceMock<StrictMock<MockFoo>>`) is **not** supported.
+2.  `NiceMock<MockFoo>` and `StrictMock<MockFoo>` may not work correctly if the
+    destructor of `MockFoo` is not virtual. We would like to fix this, but it
+    requires cleaning up existing tests.
+
+Finally, you should be **very cautious** about when to use naggy or strict
+mocks, as they tend to make tests more brittle and harder to maintain. When you
+refactor your code without changing its externally visible behavior, ideally you
+shouldn't need to update any tests. If your code interacts with a naggy mock,
+however, you may start to get spammed with warnings as the result of your
+change. Worse, if your code interacts with a strict mock, your tests may start
+to fail and you'll be forced to fix them. Our general recommendation is to use
+nice mocks (not yet the default) most of the time, use naggy mocks (the current
+default) when developing or debugging tests, and use strict mocks only as the
+last resort.
+
+### Simplifying the Interface without Breaking Existing Code {#SimplerInterfaces}
+
+Sometimes a method has a long list of arguments that is mostly uninteresting.
+For example:
+
+```cpp
+class LogSink {
+ public:
+  ...
+ virtual void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, + const struct tm* tm_time, + const char* message, size_t message_len) = 0; +}; +``` + +This method's argument list is lengthy and hard to work with (the `message` +argument is not even 0-terminated). If we mock it as is, using the mock will be +awkward. If, however, we try to simplify this interface, we'll need to fix all +clients depending on it, which is often infeasible. + +The trick is to redispatch the method in the mock class: + +```cpp +class ScopedMockLog : public LogSink { + public: + ... + void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, const tm* tm_time, + const char* message, size_t message_len) override { + // We are only interested in the log severity, full file name, and + // log message. + Log(severity, full_filename, std::string(message, message_len)); + } + + // Implements the mock method: + // + // void Log(LogSeverity severity, + // const string& file_path, + // const string& message); + MOCK_METHOD(void, Log, + (LogSeverity severity, const string& file_path, + const string& message)); +}; +``` + +By defining a new mock method with a trimmed argument list, we make the mock +class more user-friendly. + +This technique may also be applied to make overloaded methods more amenable to +mocking. For example, when overloads have been used to implement default +arguments: + +```cpp +class MockTurtleFactory : public TurtleFactory { + public: + Turtle* MakeTurtle(int length, int weight) override { ... } + Turtle* MakeTurtle(int length, int weight, int speed) override { ... 
} + + // the above methods delegate to this one: + MOCK_METHOD(Turtle*, DoMakeTurtle, ()); +}; +``` + +This allows tests that don't care which overload was invoked to avoid specifying +argument matchers: + +```cpp +ON_CALL(factory, DoMakeTurtle) + .WillByDefault(Return(MakeMockTurtle())); +``` + +### Alternative to Mocking Concrete Classes + +Often you may find yourself using classes that don't implement interfaces. In +order to test your code that uses such a class (let's call it `Concrete`), you +may be tempted to make the methods of `Concrete` virtual and then mock it. + +Try not to do that. + +Making a non-virtual function virtual is a big decision. It creates an extension +point where subclasses can tweak your class' behavior. This weakens your control +on the class because now it's harder to maintain the class invariants. You +should make a function virtual only when there is a valid reason for a subclass +to override it. + +Mocking concrete classes directly is problematic as it creates a tight coupling +between the class and the tests - any small change in the class may invalidate +your tests and make test maintenance a pain. + +To avoid such problems, many programmers have been practicing "coding to +interfaces": instead of talking to the `Concrete` class, your code would define +an interface and talk to it. Then you implement that interface as an adaptor on +top of `Concrete`. In tests, you can easily mock that interface to observe how +your code is doing. + +This technique incurs some overhead: + +* You pay the cost of virtual function calls (usually not a problem). +* There is more abstraction for the programmers to learn. + +However, it can also bring significant benefits in addition to better +testability: + +* `Concrete`'s API may not fit your problem domain very well, as you may not + be the only client it tries to serve. 
By designing your own interface, you + have a chance to tailor it to your need - you may add higher-level + functionalities, rename stuff, etc instead of just trimming the class. This + allows you to write your code (user of the interface) in a more natural way, + which means it will be more readable, more maintainable, and you'll be more + productive. +* If `Concrete`'s implementation ever has to change, you don't have to rewrite + everywhere it is used. Instead, you can absorb the change in your + implementation of the interface, and your other code and tests will be + insulated from this change. + +Some people worry that if everyone is practicing this technique, they will end +up writing lots of redundant code. This concern is totally understandable. +However, there are two reasons why it may not be the case: + +* Different projects may need to use `Concrete` in different ways, so the best + interfaces for them will be different. Therefore, each of them will have its + own domain-specific interface on top of `Concrete`, and they will not be the + same code. +* If enough projects want to use the same interface, they can always share it, + just like they have been sharing `Concrete`. You can check in the interface + and the adaptor somewhere near `Concrete` (perhaps in a `contrib` + sub-directory) and let many projects use it. + +You need to weigh the pros and cons carefully for your particular problem, but +I'd like to assure you that the Java community has been practicing this for a +long time and it's a proven effective technique applicable in a wide variety of +situations. :-) + +### Delegating Calls to a Fake {#DelegatingToFake} + +Some times you have a non-trivial fake implementation of an interface. For +example: + +```cpp +class Foo { + public: + virtual ~Foo() {} + virtual char DoThis(int n) = 0; + virtual void DoThat(const char* s, int* p) = 0; +}; + +class FakeFoo : public Foo { + public: + char DoThis(int n) override { + return (n > 0) ? 
'+' : + (n < 0) ? '-' : '0'; + } + + void DoThat(const char* s, int* p) override { + *p = strlen(s); + } +}; +``` + +Now you want to mock this interface such that you can set expectations on it. +However, you also want to use `FakeFoo` for the default behavior, as duplicating +it in the mock object is, well, a lot of work. + +When you define the mock class using gMock, you can have it delegate its default +action to a fake class you already have, using this pattern: + +```cpp +class MockFoo : public Foo { + public: + // Normal mock method definitions using gMock. + MOCK_METHOD(char, DoThis, (int n), (override)); + MOCK_METHOD(void, DoThat, (const char* s, int* p), (override)); + + // Delegates the default actions of the methods to a FakeFoo object. + // This must be called *before* the custom ON_CALL() statements. + void DelegateToFake() { + ON_CALL(*this, DoThis).WillByDefault([this](int n) { + return fake_.DoThis(n); + }); + ON_CALL(*this, DoThat).WillByDefault([this](const char* s, int* p) { + fake_.DoThat(s, p); + }); + } + + private: + FakeFoo fake_; // Keeps an instance of the fake in the mock. +}; +``` + +With that, you can use `MockFoo` in your tests as usual. Just remember that if +you don't explicitly set an action in an `ON_CALL()` or `EXPECT_CALL()`, the +fake will be called upon to do it.: + +```cpp +using ::testing::_; + +TEST(AbcTest, Xyz) { + MockFoo foo; + + foo.DelegateToFake(); // Enables the fake for delegation. + + // Put your ON_CALL(foo, ...)s here, if any. + + // No action specified, meaning to use the default action. + EXPECT_CALL(foo, DoThis(5)); + EXPECT_CALL(foo, DoThat(_, _)); + + int n = 0; + EXPECT_EQ(foo.DoThis(5), '+'); // FakeFoo::DoThis() is invoked. + foo.DoThat("Hi", &n); // FakeFoo::DoThat() is invoked. + EXPECT_EQ(n, 2); +} +``` + +**Some tips:** + +* If you want, you can still override the default action by providing your own + `ON_CALL()` or using `.WillOnce()` / `.WillRepeatedly()` in `EXPECT_CALL()`. 
+
+*   In `DelegateToFake()`, you only need to delegate the methods whose fake
+    implementation you intend to use.
+
+*   The general technique discussed here works for overloaded methods, but
+    you'll need to tell the compiler which version you mean. To disambiguate a
+    mock function (the one you specify inside the parentheses of `ON_CALL()`),
+    use [this technique](#SelectOverload); to disambiguate a fake function (the
+    one you place inside `Invoke()`), use a `static_cast` to specify the
+    function's type. For instance, if class `Foo` has methods `char DoThis(int
+    n)` and `bool DoThis(double x) const`, and you want to invoke the latter,
+    you need to write `Invoke(&fake_, static_cast<bool (FakeFoo::*)(double)
+    const>(&FakeFoo::DoThis))` instead of `Invoke(&fake_, &FakeFoo::DoThis)`
+    (The strange-looking thing inside the angled brackets of `static_cast` is
+    the type of a function pointer to the second `DoThis()` method.).
+
+*   Having to mix a mock and a fake is often a sign of something gone wrong.
+    Perhaps you haven't got used to the interaction-based way of testing yet. Or
+    perhaps your interface is taking on too many roles and should be split up.
+    Therefore, **don't abuse this**. We would only recommend to do it as an
+    intermediate step when you are refactoring your code.
+
+Regarding the tip on mixing a mock and a fake, here's an example on why it may
+be a bad sign: Suppose you have a class `System` for low-level system
+operations. In particular, it does file and I/O operations. And suppose you want
+to test how your code uses `System` to do I/O, and you just want the file
+operations to work normally. If you mock out the entire `System` class, you'll
+have to provide a fake implementation for the file operation part, which
+suggests that `System` is taking on too many roles.
+
+Instead, you can define a `FileOps` interface and an `IOOps` interface and split
+`System`'s functionalities into the two. Then you can mock `IOOps` without
+mocking `FileOps`.
+
+### Delegating Calls to a Real Object
+
+When using testing doubles (mocks, fakes, stubs, etc.), sometimes their
+behaviors will differ from those of the real objects. This difference could be
+either intentional (as in simulating an error such that you can test the error
+handling code) or unintentional. If your mocks have different behaviors than the
+real objects by mistake, you could end up with code that passes the tests but
+fails in production.
+
+You can use the *delegating-to-real* technique to ensure that your mock has the
+same behavior as the real object while retaining the ability to validate calls.
+This technique is very similar to the [delegating-to-fake](#DelegatingToFake)
+technique, the difference being that we use a real object instead of a fake.
+Here's an example:
+
+```cpp
+using ::testing::AtLeast;
+
+class MockFoo : public Foo {
+ public:
+  MockFoo() {
+    // By default, all calls are delegated to the real object.
+    ON_CALL(*this, DoThis).WillByDefault([this](int n) {
+      return real_.DoThis(n);
+    });
+    ON_CALL(*this, DoThat).WillByDefault([this](const char* s, int* p) {
+      real_.DoThat(s, p);
+    });
+    ...
+  }
+  MOCK_METHOD(char, DoThis, ...);
+  MOCK_METHOD(void, DoThat, ...);
+  ...
+ private:
+  Foo real_;
+};
+
+...
+  MockFoo mock;
+  EXPECT_CALL(mock, DoThis())
+      .Times(3);
+  EXPECT_CALL(mock, DoThat("Hi"))
+      .Times(AtLeast(1));
+  ... use mock in test ...
+```
+
+With this, gMock will verify that your code made the right calls (with the right
+arguments, in the right order, called the right number of times, etc), and a
+real object will answer the calls (so the behavior will be the same as in
+production). This gives you the best of both worlds.
+
+### Delegating Calls to a Parent Class
+
+Ideally, you should code to interfaces, whose methods are all pure virtual. In
+reality, sometimes you do need to mock a virtual method that is not pure (i.e.,
+it already has an implementation).
For example: + +```cpp +class Foo { + public: + virtual ~Foo(); + + virtual void Pure(int n) = 0; + virtual int Concrete(const char* str) { ... } +}; + +class MockFoo : public Foo { + public: + // Mocking a pure method. + MOCK_METHOD(void, Pure, (int n), (override)); + // Mocking a concrete method. Foo::Concrete() is shadowed. + MOCK_METHOD(int, Concrete, (const char* str), (override)); +}; +``` + +Sometimes you may want to call `Foo::Concrete()` instead of +`MockFoo::Concrete()`. Perhaps you want to do it as part of a stub action, or +perhaps your test doesn't need to mock `Concrete()` at all (but it would be +oh-so painful to have to define a new mock class whenever you don't need to mock +one of its methods). + +You can call `Foo::Concrete()` inside an action by: + +```cpp +... + EXPECT_CALL(foo, Concrete).WillOnce([&foo](const char* str) { + return foo.Foo::Concrete(str); + }); +``` + +or tell the mock object that you don't want to mock `Concrete()`: + +```cpp +... + ON_CALL(foo, Concrete).WillByDefault([&foo](const char* str) { + return foo.Foo::Concrete(str); + }); +``` + +(Why don't we just write `{ return foo.Concrete(str); }`? If you do that, +`MockFoo::Concrete()` will be called (and cause an infinite recursion) since +`Foo::Concrete()` is virtual. That's just how C++ works.) + +## Using Matchers + +### Matching Argument Values Exactly + +You can specify exactly which arguments a mock method is expecting: + +```cpp +using ::testing::Return; +... + EXPECT_CALL(foo, DoThis(5)) + .WillOnce(Return('a')); + EXPECT_CALL(foo, DoThat("Hello", bar)); +``` + +### Using Simple Matchers + +You can use matchers to match arguments that have a certain property: + +```cpp +using ::testing::NotNull; +using ::testing::Return; +... + EXPECT_CALL(foo, DoThis(Ge(5))) // The argument must be >= 5. + .WillOnce(Return('a')); + EXPECT_CALL(foo, DoThat("Hello", NotNull())); + // The second argument must not be NULL. 
+``` + +A frequently used matcher is `_`, which matches anything: + +```cpp + EXPECT_CALL(foo, DoThat(_, NotNull())); +``` + +### Combining Matchers {#CombiningMatchers} + +You can build complex matchers from existing ones using `AllOf()`, +`AllOfArray()`, `AnyOf()`, `AnyOfArray()` and `Not()`: + +```cpp +using ::testing::AllOf; +using ::testing::Gt; +using ::testing::HasSubstr; +using ::testing::Ne; +using ::testing::Not; +... + // The argument must be > 5 and != 10. + EXPECT_CALL(foo, DoThis(AllOf(Gt(5), + Ne(10)))); + + // The first argument must not contain sub-string "blah". + EXPECT_CALL(foo, DoThat(Not(HasSubstr("blah")), + NULL)); +``` + +Matchers are function objects, and parametrized matchers can be composed just +like any other function. However because their types can be long and rarely +provide meaningful information, it can be easier to express them with C++14 +generic lambdas to avoid specifying types. For example, + +```cpp +using ::testing::Contains; +using ::testing::Property; + +inline constexpr auto HasFoo = [](const auto& f) { + return Property("foo", &MyClass::foo, Contains(f)); +}; +... + EXPECT_THAT(x, HasFoo("blah")); +``` + +### Casting Matchers {#SafeMatcherCast} + +gMock matchers are statically typed, meaning that the compiler can catch your +mistake if you use a matcher of the wrong type (for example, if you use `Eq(5)` +to match a `string` argument). Good for you! + +Sometimes, however, you know what you're doing and want the compiler to give you +some slack. One example is that you have a matcher for `long` and the argument +you want to match is `int`. While the two types aren't exactly the same, there +is nothing really wrong with using a `Matcher` to match an `int` - after +all, we can first convert the `int` argument to a `long` losslessly before +giving it to the matcher. + +To support this need, gMock gives you the `SafeMatcherCast(m)` function. It +casts a matcher `m` to type `Matcher`. 
To ensure safety, gMock checks that
+(let `U` be the type `m` accepts):
+
+1. Type `T` can be *implicitly* cast to type `U`;
+2. When both `T` and `U` are built-in arithmetic types (`bool`, integers, and
+   floating-point numbers), the conversion from `T` to `U` is not lossy (in
+   other words, any value representable by `T` can also be represented by `U`);
+   and
+3. When `U` is a reference, `T` must also be a reference (as the underlying
+   matcher may be interested in the address of the `U` value).
+
+The code won't compile if any of these conditions isn't met.
+
+Here's one example:
+
+```cpp
+using ::testing::SafeMatcherCast;
+
+// A base class and a child class.
+class Base { ... };
+class Derived : public Base { ... };
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(void, DoThis, (Derived* derived), (override));
+};
+
+...
+  MockFoo foo;
+  // m is a Matcher<Base*> we got from somewhere.
+  EXPECT_CALL(foo, DoThis(SafeMatcherCast<Derived*>(m)));
+```
+
+If you find `SafeMatcherCast<T>(m)` too limiting, you can use a similar function
+`MatcherCast<T>(m)`. The difference is that `MatcherCast` works as long as you
+can `static_cast` type `T` to type `U`.
+
+`MatcherCast` essentially lets you bypass C++'s type system (`static_cast` isn't
+always safe as it could throw away information, for example), so be careful not
+to misuse/abuse it.
+
+### Selecting Between Overloaded Functions {#SelectOverload}
+
+If you expect an overloaded function to be called, the compiler may need some
+help on which overloaded version it is.
+
+To disambiguate functions overloaded on the const-ness of this object, use the
+`Const()` argument wrapper.
+
+```cpp
+using ::testing::ReturnRef;
+
+class MockFoo : public Foo {
+  ...
+  MOCK_METHOD(Bar&, GetBar, (), (override));
+  MOCK_METHOD(const Bar&, GetBar, (), (const, override));
+};
+
+...
+  MockFoo foo;
+  Bar bar1, bar2;
+  EXPECT_CALL(foo, GetBar())  // The non-const GetBar().
+      .WillOnce(ReturnRef(bar1));
+  EXPECT_CALL(Const(foo), GetBar())  // The const GetBar().
+      .WillOnce(ReturnRef(bar2));
+```
+
+(`Const()` is defined by gMock and returns a `const` reference to its argument.)
+
+To disambiguate overloaded functions with the same number of arguments but
+different argument types, you may need to specify the exact type of a matcher,
+either by wrapping your matcher in `Matcher<type>()`, or using a matcher whose
+type is fixed (`TypedEq<type>`, `An<type>()`, etc):
+
+```cpp
+using ::testing::An;
+using ::testing::Matcher;
+using ::testing::TypedEq;
+
+class MockPrinter : public Printer {
+ public:
+  MOCK_METHOD(void, Print, (int n), (override));
+  MOCK_METHOD(void, Print, (char c), (override));
+};
+
+TEST(PrinterTest, Print) {
+  MockPrinter printer;
+
+  EXPECT_CALL(printer, Print(An<int>()));            // void Print(int);
+  EXPECT_CALL(printer, Print(Matcher<int>(Lt(5))));  // void Print(int);
+  EXPECT_CALL(printer, Print(TypedEq<char>('a')));   // void Print(char);
+
+  printer.Print(3);
+  printer.Print(6);
+  printer.Print('a');
+}
+```
+
+### Performing Different Actions Based on the Arguments
+
+When a mock method is called, the *last* matching expectation that's still
+active will be selected (think "newer overrides older"). So, you can make a
+method do different things depending on its argument values like this:
+
+```cpp
+using ::testing::_;
+using ::testing::Lt;
+using ::testing::Return;
+...
+  // The default case.
+  EXPECT_CALL(foo, DoThis(_))
+      .WillRepeatedly(Return('b'));
+  // The more specific case.
+  EXPECT_CALL(foo, DoThis(Lt(5)))
+      .WillRepeatedly(Return('a'));
+```
+
+Now, if `foo.DoThis()` is called with a value less than 5, `'a'` will be
+returned; otherwise `'b'` will be returned.
+
+### Matching Multiple Arguments as a Whole
+
+Sometimes it's not enough to match the arguments individually. For example, we
+may want to say that the first argument must be less than the second argument.
+The `With()` clause allows us to match all arguments of a mock function as a
+whole. For example,
+
+```cpp
+using ::testing::_;
+using ::testing::Ne;
+using ::testing::Lt;
+...
+  EXPECT_CALL(foo, InRange(Ne(0), _))
+      .With(Lt());
+```
+
+says that the first argument of `InRange()` must not be 0, and must be less than
+the second argument.
+
+The expression inside `With()` must be a matcher of type
+`Matcher<std::tuple<A1, ..., An>>`, where `A1`, ..., `An` are the types of the
+function arguments.
+
+You can also write `AllArgs(m)` instead of `m` inside `.With()`. The two forms
+are equivalent, but `.With(AllArgs(Lt()))` is more readable than `.With(Lt())`.
+
+You can use `Args<k1, ..., kn>(m)` to match the `n` selected arguments (as a
+tuple) against `m`. For example,
+
+```cpp
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::Args;
+using ::testing::Lt;
+...
+  EXPECT_CALL(foo, Blah)
+      .With(AllOf(Args<0, 1>(Lt()), Args<1, 2>(Lt())));
+```
+
+says that `Blah` will be called with arguments `x`, `y`, and `z` where `x < y <
+z`. Note that in this example, it wasn't necessary to specify the positional
+matchers.
+
+As a convenience and example, gMock provides some matchers for 2-tuples,
+including the `Lt()` matcher above. See
+[Multi-argument Matchers](reference/matchers.md#MultiArgMatchers) for the
+complete list.
+
+Note that if you want to pass the arguments to a predicate of your own (e.g.
+`.With(Args<0, 1>(Truly(&MyPredicate)))`), that predicate MUST be written to
+take a `std::tuple` as its argument; gMock will pass the `n` selected arguments
+as *one* single tuple to the predicate.
+
+### Using Matchers as Predicates
+
+Have you noticed that a matcher is just a fancy predicate that also knows how to
+describe itself? Many existing algorithms take predicates as arguments (e.g.
+those defined in STL's `<algorithm>` header), and it would be a shame if gMock
+matchers were not allowed to participate.
+
+Luckily, you can use a matcher where a unary predicate functor is expected by
+wrapping it inside the `Matches()` function. For example,
+
+```cpp
+#include <algorithm>
+#include <vector>
+
+using ::testing::Matches;
+using ::testing::Ge;
+
+vector<int> v;
+...
+// How many elements in v are >= 10?
+const int count = count_if(v.begin(), v.end(), Matches(Ge(10)));
+```
+
+Since you can build complex matchers from simpler ones easily using gMock, this
+gives you a way to conveniently construct composite predicates (doing the same
+using STL's `<functional>` header is just painful). For example, here's a
+predicate that's satisfied by any number that is >= 0, <= 100, and != 50:
+
+```cpp
+using ::testing::AllOf;
+using ::testing::Ge;
+using ::testing::Le;
+using ::testing::Matches;
+using ::testing::Ne;
+...
+Matches(AllOf(Ge(0), Le(100), Ne(50)))
+```
+
+### Using Matchers in googletest Assertions
+
+See [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) in the Assertions
+Reference.
+
+### Using Predicates as Matchers
+
+gMock provides a set of built-in matchers for matching arguments with expected
+values—see the [Matchers Reference](reference/matchers.md) for more information.
+In case you find the built-in set lacking, you can use an arbitrary unary
+predicate function or functor as a matcher - as long as the predicate accepts a
+value of the type you want. You do this by wrapping the predicate inside the
+`Truly()` function, for example:
+
+```cpp
+using ::testing::Truly;
+
+int IsEven(int n) { return (n % 2) == 0 ? 1 : 0; }
+...
+  // Bar() must be called with an even number.
+  EXPECT_CALL(foo, Bar(Truly(IsEven)));
+```
+
+Note that the predicate function / functor doesn't have to return `bool`. It
+works as long as the return value can be used as the condition in the statement
+`if (condition) ...`.
+
+### Matching Arguments that Are Not Copyable
+
+When you do an `EXPECT_CALL(mock_obj, Foo(bar))`, gMock saves away a copy of
+`bar`. 
When `Foo()` is called later, gMock compares the argument to `Foo()` with +the saved copy of `bar`. This way, you don't need to worry about `bar` being +modified or destroyed after the `EXPECT_CALL()` is executed. The same is true +when you use matchers like `Eq(bar)`, `Le(bar)`, and so on. + +But what if `bar` cannot be copied (i.e. has no copy constructor)? You could +define your own matcher function or callback and use it with `Truly()`, as the +previous couple of recipes have shown. Or, you may be able to get away from it +if you can guarantee that `bar` won't be changed after the `EXPECT_CALL()` is +executed. Just tell gMock that it should save a reference to `bar`, instead of a +copy of it. Here's how: + +```cpp +using ::testing::Eq; +using ::testing::Lt; +... + // Expects that Foo()'s argument == bar. + EXPECT_CALL(mock_obj, Foo(Eq(std::ref(bar)))); + + // Expects that Foo()'s argument < bar. + EXPECT_CALL(mock_obj, Foo(Lt(std::ref(bar)))); +``` + +Remember: if you do this, don't change `bar` after the `EXPECT_CALL()`, or the +result is undefined. + +### Validating a Member of an Object + +Often a mock function takes a reference to object as an argument. When matching +the argument, you may not want to compare the entire object against a fixed +object, as that may be over-specification. Instead, you may need to validate a +certain member variable or the result of a certain getter method of the object. +You can do this with `Field()` and `Property()`. More specifically, + +```cpp +Field(&Foo::bar, m) +``` + +is a matcher that matches a `Foo` object whose `bar` member variable satisfies +matcher `m`. + +```cpp +Property(&Foo::baz, m) +``` + +is a matcher that matches a `Foo` object whose `baz()` method returns a value +that satisfies matcher `m`. + +For example: + +| Expression | Description | +| :--------------------------- | :--------------------------------------- | +| `Field(&Foo::number, Ge(3))` | Matches `x` where `x.number >= 3`. 
| +| `Property(&Foo::name, StartsWith("John "))` | Matches `x` where `x.name()` starts with `"John "`. | + +Note that in `Property(&Foo::baz, ...)`, method `baz()` must take no argument +and be declared as `const`. Don't use `Property()` against member functions that +you do not own, because taking addresses of functions is fragile and generally +not part of the contract of the function. + +`Field()` and `Property()` can also match plain pointers to objects. For +instance, + +```cpp +using ::testing::Field; +using ::testing::Ge; +... +Field(&Foo::number, Ge(3)) +``` + +matches a plain pointer `p` where `p->number >= 3`. If `p` is `NULL`, the match +will always fail regardless of the inner matcher. + +What if you want to validate more than one members at the same time? Remember +that there are [`AllOf()` and `AllOfArray()`](#CombiningMatchers). + +Finally `Field()` and `Property()` provide overloads that take the field or +property names as the first argument to include it in the error message. This +can be useful when creating combined matchers. + +```cpp +using ::testing::AllOf; +using ::testing::Field; +using ::testing::Matcher; +using ::testing::SafeMatcherCast; + +Matcher IsFoo(const Foo& foo) { + return AllOf(Field("some_field", &Foo::some_field, foo.some_field), + Field("other_field", &Foo::other_field, foo.other_field), + Field("last_field", &Foo::last_field, foo.last_field)); +} +``` + +### Validating the Value Pointed to by a Pointer Argument + +C++ functions often take pointers as arguments. You can use matchers like +`IsNull()`, `NotNull()`, and other comparison matchers to match a pointer, but +what if you want to make sure the value *pointed to* by the pointer, instead of +the pointer itself, has a certain property? Well, you can use the `Pointee(m)` +matcher. + +`Pointee(m)` matches a pointer if and only if `m` matches the value the pointer +points to. For example: + +```cpp +using ::testing::Ge; +using ::testing::Pointee; +... 
+ EXPECT_CALL(foo, Bar(Pointee(Ge(3)))); +``` + +expects `foo.Bar()` to be called with a pointer that points to a value greater +than or equal to 3. + +One nice thing about `Pointee()` is that it treats a `NULL` pointer as a match +failure, so you can write `Pointee(m)` instead of + +```cpp +using ::testing::AllOf; +using ::testing::NotNull; +using ::testing::Pointee; +... + AllOf(NotNull(), Pointee(m)) +``` + +without worrying that a `NULL` pointer will crash your test. + +Also, did we tell you that `Pointee()` works with both raw pointers **and** +smart pointers (`std::unique_ptr`, `std::shared_ptr`, etc)? + +What if you have a pointer to pointer? You guessed it - you can use nested +`Pointee()` to probe deeper inside the value. For example, +`Pointee(Pointee(Lt(3)))` matches a pointer that points to a pointer that points +to a number less than 3 (what a mouthful...). + +### Defining a Custom Matcher Class {#CustomMatcherClass} + +Most matchers can be simply defined using [the MATCHER* macros](#NewMatchers), +which are terse and flexible, and produce good error messages. However, these +macros are not very explicit about the interfaces they create and are not always +suitable, especially for matchers that will be widely reused. + +For more advanced cases, you may need to define your own matcher class. A custom +matcher allows you to test a specific invariant property of that object. Let's +take a look at how to do so. + +Imagine you have a mock function that takes an object of type `Foo`, which has +an `int bar()` method and an `int baz()` method. You want to constrain that the +argument's `bar()` value plus its `baz()` value is a given number. (This is an +invariant.) 
Here's how we can write and use a matcher class to do so: + +```cpp +class BarPlusBazEqMatcher { + public: + using is_gtest_matcher = void; + + explicit BarPlusBazEqMatcher(int expected_sum) + : expected_sum_(expected_sum) {} + + bool MatchAndExplain(const Foo& foo, + std::ostream* /* listener */) const { + return (foo.bar() + foo.baz()) == expected_sum_; + } + + void DescribeTo(std::ostream* os) const { + *os << "bar() + baz() equals " << expected_sum_; + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "bar() + baz() does not equal " << expected_sum_; + } + private: + const int expected_sum_; +}; + +::testing::Matcher BarPlusBazEq(int expected_sum) { + return BarPlusBazEqMatcher(expected_sum); +} + +... + Foo foo; + EXPECT_THAT(foo, BarPlusBazEq(5))...; +``` + +### Matching Containers + +Sometimes an STL container (e.g. list, vector, map, ...) is passed to a mock +function and you may want to validate it. Since most STL containers support the +`==` operator, you can write `Eq(expected_container)` or simply +`expected_container` to match a container exactly. + +Sometimes, though, you may want to be more flexible (for example, the first +element must be an exact match, but the second element can be any positive +number, and so on). Also, containers used in tests often have a small number of +elements, and having to define the expected container out-of-line is a bit of a +hassle. + +You can use the `ElementsAre()` or `UnorderedElementsAre()` matcher in such +cases: + +```cpp +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Gt; +... + MOCK_METHOD(void, Foo, (const vector& numbers), (override)); +... + EXPECT_CALL(mock, Foo(ElementsAre(1, Gt(0), _, 5))); +``` + +The above matcher says that the container must have 4 elements, which must be 1, +greater than 0, anything, and 5 respectively. + +If you instead write: + +```cpp +using ::testing::_; +using ::testing::Gt; +using ::testing::UnorderedElementsAre; +... 
+ MOCK_METHOD(void, Foo, (const vector& numbers), (override)); +... + EXPECT_CALL(mock, Foo(UnorderedElementsAre(1, Gt(0), _, 5))); +``` + +It means that the container must have 4 elements, which (under some permutation) +must be 1, greater than 0, anything, and 5 respectively. + +As an alternative you can place the arguments in a C-style array and use +`ElementsAreArray()` or `UnorderedElementsAreArray()` instead: + +```cpp +using ::testing::ElementsAreArray; +... + // ElementsAreArray accepts an array of element values. + const int expected_vector1[] = {1, 5, 2, 4, ...}; + EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector1))); + + // Or, an array of element matchers. + Matcher expected_vector2[] = {1, Gt(2), _, 3, ...}; + EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector2))); +``` + +In case the array needs to be dynamically created (and therefore the array size +cannot be inferred by the compiler), you can give `ElementsAreArray()` an +additional argument to specify the array size: + +```cpp +using ::testing::ElementsAreArray; +... + int* const expected_vector3 = new int[count]; + ... fill expected_vector3 with values ... + EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector3, count))); +``` + +Use `Pair` when comparing maps or other associative containers. + +{% raw %} + +```cpp +using ::testing::UnorderedElementsAre; +using ::testing::Pair; +... + absl::flat_hash_map m = {{"a", 1}, {"b", 2}, {"c", 3}}; + EXPECT_THAT(m, UnorderedElementsAre( + Pair("a", 1), Pair("b", 2), Pair("c", 3))); +``` + +{% endraw %} + +**Tips:** + +* `ElementsAre*()` can be used to match *any* container that implements the + STL iterator pattern (i.e. it has a `const_iterator` type and supports + `begin()/end()`), not just the ones defined in STL. It will even work with + container types yet to be written - as long as they follows the above + pattern. +* You can use nested `ElementsAre*()` to match nested (multi-dimensional) + containers. 
+* If the container is passed by pointer instead of by reference, just write + `Pointee(ElementsAre*(...))`. +* The order of elements *matters* for `ElementsAre*()`. If you are using it + with containers whose element order are undefined (such as a + `std::unordered_map`) you should use `UnorderedElementsAre`. + +### Sharing Matchers + +Under the hood, a gMock matcher object consists of a pointer to a ref-counted +implementation object. Copying matchers is allowed and very efficient, as only +the pointer is copied. When the last matcher that references the implementation +object dies, the implementation object will be deleted. + +Therefore, if you have some complex matcher that you want to use again and +again, there is no need to build it every time. Just assign it to a matcher +variable and use that variable repeatedly! For example, + +```cpp +using ::testing::AllOf; +using ::testing::Gt; +using ::testing::Le; +using ::testing::Matcher; +... + Matcher in_range = AllOf(Gt(5), Le(10)); + ... use in_range as a matcher in multiple EXPECT_CALLs ... +``` + +### Matchers must have no side-effects {#PureMatchers} + +{: .callout .warning} +WARNING: gMock does not guarantee when or how many times a matcher will be +invoked. Therefore, all matchers must be *purely functional*: they cannot have +any side effects, and the match result must not depend on anything other than +the matcher's parameters and the value being matched. + +This requirement must be satisfied no matter how a matcher is defined (e.g., if +it is one of the standard matchers, or a custom matcher). In particular, a +matcher can never call a mock function, as that will affect the state of the +mock object and gMock. + +## Setting Expectations + +### Knowing When to Expect {#UseOnCall} + +**`ON_CALL`** is likely the *single most under-utilized construct* in gMock. + +There are basically two constructs for defining the behavior of a mock object: +`ON_CALL` and `EXPECT_CALL`. The difference? 
`ON_CALL` defines what happens when +a mock method is called, but doesn't imply any expectation on the method +being called. `EXPECT_CALL` not only defines the behavior, but also sets an +expectation that the method will be called with the given arguments, for the +given number of times (and *in the given order* when you specify the order +too). + +Since `EXPECT_CALL` does more, isn't it better than `ON_CALL`? Not really. Every +`EXPECT_CALL` adds a constraint on the behavior of the code under test. Having +more constraints than necessary is *baaad* - even worse than not having enough +constraints. + +This may be counter-intuitive. How could tests that verify more be worse than +tests that verify less? Isn't verification the whole point of tests? + +The answer lies in *what* a test should verify. **A good test verifies the +contract of the code.** If a test over-specifies, it doesn't leave enough +freedom to the implementation. As a result, changing the implementation without +breaking the contract (e.g. refactoring and optimization), which should be +perfectly fine to do, can break such tests. Then you have to spend time fixing +them, only to see them broken again the next time the implementation is changed. + +Keep in mind that one doesn't have to verify more than one property in one test. +In fact, **it's a good style to verify only one thing in one test.** If you do +that, a bug will likely break only one or two tests instead of dozens (which +case would you rather debug?). If you are also in the habit of giving tests +descriptive names that tell what they verify, you can often easily guess what's +wrong just from the test log itself. + +So use `ON_CALL` by default, and only use `EXPECT_CALL` when you actually intend +to verify that the call is made. 
For example, you may have a bunch of `ON_CALL`s +in your test fixture to set the common mock behavior shared by all tests in the +same group, and write (scarcely) different `EXPECT_CALL`s in different `TEST_F`s +to verify different aspects of the code's behavior. Compared with the style +where each `TEST` has many `EXPECT_CALL`s, this leads to tests that are more +resilient to implementational changes (and thus less likely to require +maintenance) and makes the intent of the tests more obvious (so they are easier +to maintain when you do need to maintain them). + +If you are bothered by the "Uninteresting mock function call" message printed +when a mock method without an `EXPECT_CALL` is called, you may use a `NiceMock` +instead to suppress all such messages for the mock object, or suppress the +message for specific methods by adding `EXPECT_CALL(...).Times(AnyNumber())`. DO +NOT suppress it by blindly adding an `EXPECT_CALL(...)`, or you'll have a test +that's a pain to maintain. + +### Ignoring Uninteresting Calls + +If you are not interested in how a mock method is called, just don't say +anything about it. In this case, if the method is ever called, gMock will +perform its default action to allow the test program to continue. If you are not +happy with the default action taken by gMock, you can override it using +`DefaultValue::Set()` (described [here](#DefaultValue)) or `ON_CALL()`. + +Please note that once you expressed interest in a particular mock method (via +`EXPECT_CALL()`), all invocations to it must match some expectation. If this +function is called but the arguments don't match any `EXPECT_CALL()` statement, +it will be an error. + +### Disallowing Unexpected Calls + +If a mock method shouldn't be called at all, explicitly say so: + +```cpp +using ::testing::_; +... 
+ EXPECT_CALL(foo, Bar(_)) + .Times(0); +``` + +If some calls to the method are allowed, but the rest are not, just list all the +expected calls: + +```cpp +using ::testing::AnyNumber; +using ::testing::Gt; +... + EXPECT_CALL(foo, Bar(5)); + EXPECT_CALL(foo, Bar(Gt(10))) + .Times(AnyNumber()); +``` + +A call to `foo.Bar()` that doesn't match any of the `EXPECT_CALL()` statements +will be an error. + +### Understanding Uninteresting vs Unexpected Calls {#uninteresting-vs-unexpected} + +*Uninteresting* calls and *unexpected* calls are different concepts in gMock. +*Very* different. + +A call `x.Y(...)` is **uninteresting** if there's *not even a single* +`EXPECT_CALL(x, Y(...))` set. In other words, the test isn't interested in the +`x.Y()` method at all, as evident in that the test doesn't care to say anything +about it. + +A call `x.Y(...)` is **unexpected** if there are *some* `EXPECT_CALL(x, +Y(...))`s set, but none of them matches the call. Put another way, the test is +interested in the `x.Y()` method (therefore it explicitly sets some +`EXPECT_CALL` to verify how it's called); however, the verification fails as the +test doesn't expect this particular call to happen. + +**An unexpected call is always an error,** as the code under test doesn't behave +the way the test expects it to behave. + +**By default, an uninteresting call is not an error,** as it violates no +constraint specified by the test. (gMock's philosophy is that saying nothing +means there is no constraint.) However, it leads to a warning, as it *might* +indicate a problem (e.g. the test author might have forgotten to specify a +constraint). + +In gMock, `NiceMock` and `StrictMock` can be used to make a mock class "nice" or +"strict". How does this affect uninteresting calls and unexpected calls? + +A **nice mock** suppresses uninteresting call *warnings*. It is less chatty than +the default mock, but otherwise is the same. 
If a test fails with a default +mock, it will also fail using a nice mock instead. And vice versa. Don't expect +making a mock nice to change the test's result. + +A **strict mock** turns uninteresting call warnings into errors. So making a +mock strict may change the test's result. + +Let's look at an example: + +```cpp +TEST(...) { + NiceMock mock_registry; + EXPECT_CALL(mock_registry, GetDomainOwner("google.com")) + .WillRepeatedly(Return("Larry Page")); + + // Use mock_registry in code under test. + ... &mock_registry ... +} +``` + +The sole `EXPECT_CALL` here says that all calls to `GetDomainOwner()` must have +`"google.com"` as the argument. If `GetDomainOwner("yahoo.com")` is called, it +will be an unexpected call, and thus an error. *Having a nice mock doesn't +change the severity of an unexpected call.* + +So how do we tell gMock that `GetDomainOwner()` can be called with some other +arguments as well? The standard technique is to add a "catch all" `EXPECT_CALL`: + +```cpp + EXPECT_CALL(mock_registry, GetDomainOwner(_)) + .Times(AnyNumber()); // catches all other calls to this method. + EXPECT_CALL(mock_registry, GetDomainOwner("google.com")) + .WillRepeatedly(Return("Larry Page")); +``` + +Remember that `_` is the wildcard matcher that matches anything. With this, if +`GetDomainOwner("google.com")` is called, it will do what the second +`EXPECT_CALL` says; if it is called with a different argument, it will do what +the first `EXPECT_CALL` says. + +Note that the order of the two `EXPECT_CALL`s is important, as a newer +`EXPECT_CALL` takes precedence over an older one. + +For more on uninteresting calls, nice mocks, and strict mocks, read +["The Nice, the Strict, and the Naggy"](#NiceStrictNaggy). + +### Ignoring Uninteresting Arguments {#ParameterlessExpectations} + +If your test doesn't care about the parameters (it only cares about the number +or order of calls), you can often simply omit the parameter list: + +```cpp + // Expect foo.Bar( ... 
) twice with any arguments. + EXPECT_CALL(foo, Bar).Times(2); + + // Delegate to the given method whenever the factory is invoked. + ON_CALL(foo_factory, MakeFoo) + .WillByDefault(&BuildFooForTest); +``` + +This functionality is only available when a method is not overloaded; to prevent +unexpected behavior it is a compilation error to try to set an expectation on a +method where the specific overload is ambiguous. You can work around this by +supplying a [simpler mock interface](#SimplerInterfaces) than the mocked class +provides. + +This pattern is also useful when the arguments are interesting, but match logic +is substantially complex. You can leave the argument list unspecified and use +SaveArg actions to [save the values for later verification](#SaveArgVerify). If +you do that, you can easily differentiate calling the method the wrong number of +times from calling it with the wrong arguments. + +### Expecting Ordered Calls {#OrderedCalls} + +Although an `EXPECT_CALL()` statement defined later takes precedence when gMock +tries to match a function call with an expectation, by default calls don't have +to happen in the order `EXPECT_CALL()` statements are written. For example, if +the arguments match the matchers in the second `EXPECT_CALL()`, but not those in +the first and third, then the second expectation will be used. + +If you would rather have all calls occur in the order of the expectations, put +the `EXPECT_CALL()` statements in a block where you define a variable of type +`InSequence`: + +```cpp +using ::testing::_; +using ::testing::InSequence; + + { + InSequence s; + + EXPECT_CALL(foo, DoThis(5)); + EXPECT_CALL(bar, DoThat(_)) + .Times(2); + EXPECT_CALL(foo, DoThis(6)); + } +``` + +In this example, we expect a call to `foo.DoThis(5)`, followed by two calls to +`bar.DoThat()` where the argument can be anything, which are in turn followed by +a call to `foo.DoThis(6)`. If a call occurred out-of-order, gMock will report an +error. 
+ +### Expecting Partially Ordered Calls {#PartialOrder} + +Sometimes requiring everything to occur in a predetermined order can lead to +brittle tests. For example, we may care about `A` occurring before both `B` and +`C`, but aren't interested in the relative order of `B` and `C`. In this case, +the test should reflect our real intent, instead of being overly constraining. + +gMock allows you to impose an arbitrary DAG (directed acyclic graph) on the +calls. One way to express the DAG is to use the +[`After` clause](reference/mocking.md#EXPECT_CALL.After) of `EXPECT_CALL`. + +Another way is via the `InSequence()` clause (not the same as the `InSequence` +class), which we borrowed from jMock 2. It's less flexible than `After()`, but +more convenient when you have long chains of sequential calls, as it doesn't +require you to come up with different names for the expectations in the chains. +Here's how it works: + +If we view `EXPECT_CALL()` statements as nodes in a graph, and add an edge from +node A to node B wherever A must occur before B, we can get a DAG. We use the +term "sequence" to mean a directed path in this DAG. Now, if we decompose the +DAG into sequences, we just need to know which sequences each `EXPECT_CALL()` +belongs to in order to be able to reconstruct the original DAG. + +So, to specify the partial order on the expectations we need to do two things: +first to define some `Sequence` objects, and then for each `EXPECT_CALL()` say +which `Sequence` objects it is part of. + +Expectations in the same sequence must occur in the order they are written. For +example, + +```cpp +using ::testing::Sequence; +... 
+ Sequence s1, s2; + + EXPECT_CALL(foo, A()) + .InSequence(s1, s2); + EXPECT_CALL(bar, B()) + .InSequence(s1); + EXPECT_CALL(bar, C()) + .InSequence(s2); + EXPECT_CALL(foo, D()) + .InSequence(s2); +``` + +specifies the following DAG (where `s1` is `A -> B`, and `s2` is `A -> C -> D`): + +```text + +---> B + | + A ---| + | + +---> C ---> D +``` + +This means that A must occur before B and C, and C must occur before D. There's +no restriction about the order other than these. + +### Controlling When an Expectation Retires + +When a mock method is called, gMock only considers expectations that are still +active. An expectation is active when created, and becomes inactive (aka +*retires*) when a call that has to occur later has occurred. For example, in + +```cpp +using ::testing::_; +using ::testing::Sequence; +... + Sequence s1, s2; + + EXPECT_CALL(log, Log(WARNING, _, "File too large.")) // #1 + .Times(AnyNumber()) + .InSequence(s1, s2); + EXPECT_CALL(log, Log(WARNING, _, "Data set is empty.")) // #2 + .InSequence(s1); + EXPECT_CALL(log, Log(WARNING, _, "User not found.")) // #3 + .InSequence(s2); +``` + +as soon as either #2 or #3 is matched, #1 will retire. If a warning `"File too +large."` is logged after this, it will be an error. + +Note that an expectation doesn't retire automatically when it's saturated. For +example, + +```cpp +using ::testing::_; +... + EXPECT_CALL(log, Log(WARNING, _, _)); // #1 + EXPECT_CALL(log, Log(WARNING, _, "File too large.")); // #2 +``` + +says that there will be exactly one warning with the message `"File too +large."`. If the second warning contains this message too, #2 will match again +and result in an upper-bound-violated error. + +If this is not what you want, you can ask an expectation to retire as soon as it +becomes saturated: + +```cpp +using ::testing::_; +... 
+ EXPECT_CALL(log, Log(WARNING, _, _)); // #1 + EXPECT_CALL(log, Log(WARNING, _, "File too large.")) // #2 + .RetiresOnSaturation(); +``` + +Here #2 can be used only once, so if you have two warnings with the message +`"File too large."`, the first will match #2 and the second will match #1 - +there will be no error. + +## Using Actions + +### Returning References from Mock Methods + +If a mock function's return type is a reference, you need to use `ReturnRef()` +instead of `Return()` to return a result: + +```cpp +using ::testing::ReturnRef; + +class MockFoo : public Foo { + public: + MOCK_METHOD(Bar&, GetBar, (), (override)); +}; +... + MockFoo foo; + Bar bar; + EXPECT_CALL(foo, GetBar()) + .WillOnce(ReturnRef(bar)); +... +``` + +### Returning Live Values from Mock Methods + +The `Return(x)` action saves a copy of `x` when the action is created, and +always returns the same value whenever it's executed. Sometimes you may want to +instead return the *live* value of `x` (i.e. its value at the time when the +action is *executed*.). Use either `ReturnRef()` or `ReturnPointee()` for this +purpose. + +If the mock function's return type is a reference, you can do it using +`ReturnRef(x)`, as shown in the previous recipe ("Returning References from Mock +Methods"). However, gMock doesn't let you use `ReturnRef()` in a mock function +whose return type is not a reference, as doing that usually indicates a user +error. So, what shall you do? + +Though you may be tempted, DO NOT use `std::ref()`: + +```cpp +using ::testing::Return; + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, GetValue, (), (override)); +}; +... + int x = 0; + MockFoo foo; + EXPECT_CALL(foo, GetValue()) + .WillRepeatedly(Return(std::ref(x))); // Wrong! + x = 42; + EXPECT_EQ(foo.GetValue(), 42); +``` + +Unfortunately, it doesn't work here. 
The above code will fail with error:
+
+```text
+Value of: foo.GetValue()
+  Actual: 0
+Expected: 42
+```
+
+The reason is that `Return(*value*)` converts `value` to the actual return type
+of the mock function at the time when the action is *created*, not when it is
+*executed*. (This behavior was chosen for the action to be safe when `value` is
+a proxy object that references some temporary objects.) As a result,
+`std::ref(x)` is converted to an `int` value (instead of a `const int&`) when
+the expectation is set, and `Return(std::ref(x))` will always return 0.
+
+`ReturnPointee(pointer)` was provided to solve this problem specifically. It
+returns the value pointed to by `pointer` at the time the action is *executed*:
+
+```cpp
+using ::testing::ReturnPointee;
+...
+  int x = 0;
+  MockFoo foo;
+  EXPECT_CALL(foo, GetValue())
+      .WillRepeatedly(ReturnPointee(&x));  // Note the & here.
+  x = 42;
+  EXPECT_EQ(foo.GetValue(), 42);  // This will succeed now.
+```
+
+### Combining Actions
+
+Want to do more than one thing when a function is called? That's fine. `DoAll()`
+allows you to do a sequence of actions every time. Only the return value of the
+last action in the sequence will be used.
+
+```cpp
+using ::testing::_;
+using ::testing::DoAll;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(bool, Bar, (int n), (override));
+};
+...
+  EXPECT_CALL(foo, Bar(_))
+      .WillOnce(DoAll(action_1,
+                      action_2,
+                      ...
+                      action_n));
+```
+
+### Verifying Complex Arguments {#SaveArgVerify}
+
+If you want to verify that a method is called with a particular argument but the
+match criteria is complex, it can be difficult to distinguish between
+cardinality failures (calling the method the wrong number of times) and argument
+match failures. Similarly, if you are matching multiple parameters, it may not
+be easy to distinguish which argument failed to match.
For example: + +```cpp + // Not ideal: this could fail because of a problem with arg1 or arg2, or maybe + // just the method wasn't called. + EXPECT_CALL(foo, SendValues(_, ElementsAre(1, 4, 4, 7), EqualsProto( ... ))); +``` + +You can instead save the arguments and test them individually: + +```cpp + EXPECT_CALL(foo, SendValues) + .WillOnce(DoAll(SaveArg<1>(&actual_array), SaveArg<2>(&actual_proto))); + ... run the test + EXPECT_THAT(actual_array, ElementsAre(1, 4, 4, 7)); + EXPECT_THAT(actual_proto, EqualsProto( ... )); +``` + +### Mocking Side Effects {#MockingSideEffects} + +Sometimes a method exhibits its effect not via returning a value but via side +effects. For example, it may change some global state or modify an output +argument. To mock side effects, in general you can define your own action by +implementing `::testing::ActionInterface`. + +If all you need to do is to change an output argument, the built-in +`SetArgPointee()` action is convenient: + +```cpp +using ::testing::_; +using ::testing::SetArgPointee; + +class MockMutator : public Mutator { + public: + MOCK_METHOD(void, Mutate, (bool mutate, int* value), (override)); + ... +} +... + MockMutator mutator; + EXPECT_CALL(mutator, Mutate(true, _)) + .WillOnce(SetArgPointee<1>(5)); +``` + +In this example, when `mutator.Mutate()` is called, we will assign 5 to the +`int` variable pointed to by argument #1 (0-based). + +`SetArgPointee()` conveniently makes an internal copy of the value you pass to +it, removing the need to keep the value in scope and alive. The implication +however is that the value must have a copy constructor and assignment operator. + +If the mock method also needs to return a value as well, you can chain +`SetArgPointee()` with `Return()` using `DoAll()`, remembering to put the +`Return()` statement last: + +```cpp +using ::testing::_; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; + +class MockMutator : public Mutator { + public: + ... 
+
+  MOCK_METHOD(bool, MutateInt, (int* value), (override));
+}
+...
+  MockMutator mutator;
+  EXPECT_CALL(mutator, MutateInt(_))
+      .WillOnce(DoAll(SetArgPointee<0>(5),
+                      Return(true)));
+```
+
+Note, however, that if you use the `ReturnOKWith()` method, it will override the
+values provided by `SetArgPointee()` in the response parameters of your function
+call.
+
+If the output argument is an array, use the `SetArrayArgument<N>(first, last)`
+action instead. It copies the elements in source range `[first, last)` to the
+array pointed to by the `N`-th (0-based) argument:
+
+```cpp
+using ::testing::NotNull;
+using ::testing::SetArrayArgument;
+
+class MockArrayMutator : public ArrayMutator {
+ public:
+  MOCK_METHOD(void, Mutate, (int* values, int num_values), (override));
+  ...
+}
+...
+  MockArrayMutator mutator;
+  int values[5] = {1, 2, 3, 4, 5};
+  EXPECT_CALL(mutator, Mutate(NotNull(), 5))
+      .WillOnce(SetArrayArgument<0>(values, values + 5));
+```
+
+This also works when the argument is an output iterator:
+
+```cpp
+using ::testing::_;
+using ::testing::SetArrayArgument;
+
+class MockRolodex : public Rolodex {
+ public:
+  MOCK_METHOD(void, GetNames, (std::back_insert_iterator<vector<string>>),
+              (override));
+  ...
+}
+...
+  MockRolodex rolodex;
+  vector<string> names = {"George", "John", "Thomas"};
+  EXPECT_CALL(rolodex, GetNames(_))
+      .WillOnce(SetArrayArgument<0>(names.begin(), names.end()));
+```
+
+### Changing a Mock Object's Behavior Based on the State
+
+If you expect a call to change the behavior of a mock object, you can use
+`::testing::InSequence` to specify different behaviors before and after the
+call:
+
+```cpp
+using ::testing::InSequence;
+using ::testing::Return;
+
+...
+ { + InSequence seq; + EXPECT_CALL(my_mock, IsDirty()) + .WillRepeatedly(Return(true)); + EXPECT_CALL(my_mock, Flush()); + EXPECT_CALL(my_mock, IsDirty()) + .WillRepeatedly(Return(false)); + } + my_mock.FlushIfDirty(); +``` + +This makes `my_mock.IsDirty()` return `true` before `my_mock.Flush()` is called +and return `false` afterwards. + +If the behavior change is more complex, you can store the effects in a variable +and make a mock method get its return value from that variable: + +```cpp +using ::testing::_; +using ::testing::SaveArg; +using ::testing::Return; + +ACTION_P(ReturnPointee, p) { return *p; } +... + int previous_value = 0; + EXPECT_CALL(my_mock, GetPrevValue) + .WillRepeatedly(ReturnPointee(&previous_value)); + EXPECT_CALL(my_mock, UpdateValue) + .WillRepeatedly(SaveArg<0>(&previous_value)); + my_mock.DoSomethingToUpdateValue(); +``` + +Here `my_mock.GetPrevValue()` will always return the argument of the last +`UpdateValue()` call. + +### Setting the Default Value for a Return Type {#DefaultValue} + +If a mock method's return type is a built-in C++ type or pointer, by default it +will return 0 when invoked. Also, in C++ 11 and above, a mock method whose +return type has a default constructor will return a default-constructed value by +default. You only need to specify an action if this default value doesn't work +for you. + +Sometimes, you may want to change this default value, or you may want to specify +a default value for types gMock doesn't know about. You can do this using the +`::testing::DefaultValue` class template: + +```cpp +using ::testing::DefaultValue; + +class MockFoo : public Foo { + public: + MOCK_METHOD(Bar, CalculateBar, (), (override)); +}; + + +... + Bar default_bar; + // Sets the default return value for type Bar. + DefaultValue::Set(default_bar); + + MockFoo foo; + + // We don't need to specify an action here, as the default + // return value works for us. 
+ EXPECT_CALL(foo, CalculateBar()); + + foo.CalculateBar(); // This should return default_bar. + + // Unsets the default return value. + DefaultValue::Clear(); +``` + +Please note that changing the default value for a type can make your tests hard +to understand. We recommend you to use this feature judiciously. For example, +you may want to make sure the `Set()` and `Clear()` calls are right next to the +code that uses your mock. + +### Setting the Default Actions for a Mock Method + +You've learned how to change the default value of a given type. However, this +may be too coarse for your purpose: perhaps you have two mock methods with the +same return type and you want them to have different behaviors. The `ON_CALL()` +macro allows you to customize your mock's behavior at the method level: + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Gt; +using ::testing::Return; +... + ON_CALL(foo, Sign(_)) + .WillByDefault(Return(-1)); + ON_CALL(foo, Sign(0)) + .WillByDefault(Return(0)); + ON_CALL(foo, Sign(Gt(0))) + .WillByDefault(Return(1)); + + EXPECT_CALL(foo, Sign(_)) + .Times(AnyNumber()); + + foo.Sign(5); // This should return 1. + foo.Sign(-9); // This should return -1. + foo.Sign(0); // This should return 0. +``` + +As you may have guessed, when there are more than one `ON_CALL()` statements, +the newer ones in the order take precedence over the older ones. In other words, +the **last** one that matches the function arguments will be used. This matching +order allows you to set up the common behavior in a mock object's constructor or +the test fixture's set-up phase and specialize the mock's behavior later. + +Note that both `ON_CALL` and `EXPECT_CALL` have the same "later statements take +precedence" rule, but they don't interact. That is, `EXPECT_CALL`s have their +own precedence order distinct from the `ON_CALL` precedence order. 
+ +### Using Functions/Methods/Functors/Lambdas as Actions {#FunctionsAsActions} + +If the built-in actions don't suit you, you can use an existing callable +(function, `std::function`, method, functor, lambda) as an action. + +```cpp +using ::testing::_; using ::testing::Invoke; + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, Sum, (int x, int y), (override)); + MOCK_METHOD(bool, ComplexJob, (int x), (override)); +}; + +int CalculateSum(int x, int y) { return x + y; } +int Sum3(int x, int y, int z) { return x + y + z; } + +class Helper { + public: + bool ComplexJob(int x); +}; + +... + MockFoo foo; + Helper helper; + EXPECT_CALL(foo, Sum(_, _)) + .WillOnce(&CalculateSum) + .WillRepeatedly(Invoke(NewPermanentCallback(Sum3, 1))); + EXPECT_CALL(foo, ComplexJob(_)) + .WillOnce(Invoke(&helper, &Helper::ComplexJob)) + .WillOnce([] { return true; }) + .WillRepeatedly([](int x) { return x > 0; }); + + foo.Sum(5, 6); // Invokes CalculateSum(5, 6). + foo.Sum(2, 3); // Invokes Sum3(1, 2, 3). + foo.ComplexJob(10); // Invokes helper.ComplexJob(10). + foo.ComplexJob(-1); // Invokes the inline lambda. +``` + +The only requirement is that the type of the function, etc must be *compatible* +with the signature of the mock function, meaning that the latter's arguments (if +it takes any) can be implicitly converted to the corresponding arguments of the +former, and the former's return type can be implicitly converted to that of the +latter. So, you can invoke something whose type is *not* exactly the same as the +mock function, as long as it's safe to do so - nice, huh? + +Note that: + +* The action takes ownership of the callback and will delete it when the + action itself is destructed. +* If the type of a callback is derived from a base callback type `C`, you need + to implicitly cast it to `C` to resolve the overloading, e.g. + + ```cpp + using ::testing::Invoke; + ... + ResultCallback* is_ok = ...; + ... Invoke(is_ok) ...; // This works. 
+ + BlockingClosure* done = new BlockingClosure; + ... Invoke(implicit_cast(done)) ...; // The cast is necessary. + ``` + +### Using Functions with Extra Info as Actions + +The function or functor you call using `Invoke()` must have the same number of +arguments as the mock function you use it for. Sometimes you may have a function +that takes more arguments, and you are willing to pass in the extra arguments +yourself to fill the gap. You can do this in gMock using callbacks with +pre-bound arguments. Here's an example: + +```cpp +using ::testing::Invoke; + +class MockFoo : public Foo { + public: + MOCK_METHOD(char, DoThis, (int n), (override)); +}; + +char SignOfSum(int x, int y) { + const int sum = x + y; + return (sum > 0) ? '+' : (sum < 0) ? '-' : '0'; +} + +TEST_F(FooTest, Test) { + MockFoo foo; + + EXPECT_CALL(foo, DoThis(2)) + .WillOnce(Invoke(NewPermanentCallback(SignOfSum, 5))); + EXPECT_EQ(foo.DoThis(2), '+'); // Invokes SignOfSum(5, 2). +} +``` + +### Invoking a Function/Method/Functor/Lambda/Callback Without Arguments + +`Invoke()` passes the mock function's arguments to the function, etc being +invoked such that the callee has the full context of the call to work with. If +the invoked function is not interested in some or all of the arguments, it can +simply ignore them. + +Yet, a common pattern is that a test author wants to invoke a function without +the arguments of the mock function. She could do that using a wrapper function +that throws away the arguments before invoking an underlining nullary function. +Needless to say, this can be tedious and obscures the intent of the test. + +There are two solutions to this problem. First, you can pass any callable of +zero args as an action. Alternatively, use `InvokeWithoutArgs()`, which is like +`Invoke()` except that it doesn't pass the mock function's arguments to the +callee. 
Here's an example of each: + +```cpp +using ::testing::_; +using ::testing::InvokeWithoutArgs; + +class MockFoo : public Foo { + public: + MOCK_METHOD(bool, ComplexJob, (int n), (override)); +}; + +bool Job1() { ... } +bool Job2(int n, char c) { ... } + +... + MockFoo foo; + EXPECT_CALL(foo, ComplexJob(_)) + .WillOnce([] { Job1(); }); + .WillOnce(InvokeWithoutArgs(NewPermanentCallback(Job2, 5, 'a'))); + + foo.ComplexJob(10); // Invokes Job1(). + foo.ComplexJob(20); // Invokes Job2(5, 'a'). +``` + +Note that: + +* The action takes ownership of the callback and will delete it when the + action itself is destructed. +* If the type of a callback is derived from a base callback type `C`, you need + to implicitly cast it to `C` to resolve the overloading, e.g. + + ```cpp + using ::testing::InvokeWithoutArgs; + ... + ResultCallback* is_ok = ...; + ... InvokeWithoutArgs(is_ok) ...; // This works. + + BlockingClosure* done = ...; + ... InvokeWithoutArgs(implicit_cast(done)) ...; + // The cast is necessary. + ``` + +### Invoking an Argument of the Mock Function + +Sometimes a mock function will receive a function pointer, a functor (in other +words, a "callable") as an argument, e.g. + +```cpp +class MockFoo : public Foo { + public: + MOCK_METHOD(bool, DoThis, (int n, (ResultCallback1* callback)), + (override)); +}; +``` + +and you may want to invoke this callable argument: + +```cpp +using ::testing::_; +... + MockFoo foo; + EXPECT_CALL(foo, DoThis(_, _)) + .WillOnce(...); + // Will execute callback->Run(5), where callback is the + // second argument DoThis() receives. +``` + +{: .callout .note} +NOTE: The section below is legacy documentation from before C++ had lambdas: + +Arghh, you need to refer to a mock function argument but C++ has no lambda +(yet), so you have to define your own action. :-( Or do you really? 
+
+Well, gMock has an action to solve *exactly* this problem:
+
+```cpp
+InvokeArgument<N>(arg_1, arg_2, ..., arg_m)
+```
+
+will invoke the `N`-th (0-based) argument the mock function receives, with
+`arg_1`, `arg_2`, ..., and `arg_m`. No matter if the argument is a function
+pointer, a functor, or a callback. gMock handles them all.
+
+With that, you could write:
+
+```cpp
+using ::testing::_;
+using ::testing::InvokeArgument;
+...
+  EXPECT_CALL(foo, DoThis(_, _))
+      .WillOnce(InvokeArgument<1>(5));
+      // Will execute callback->Run(5), where callback is the
+      // second argument DoThis() receives.
+```
+
+What if the callable takes an argument by reference? No problem - just wrap it
+inside `std::ref()`:
+
+```cpp
+  ...
+  MOCK_METHOD(bool, Bar,
+              ((ResultCallback2<bool, int, const Helper&>* callback)),
+              (override));
+  ...
+  using ::testing::_;
+  using ::testing::InvokeArgument;
+  ...
+  MockFoo foo;
+  Helper helper;
+  ...
+  EXPECT_CALL(foo, Bar(_))
+      .WillOnce(InvokeArgument<0>(5, std::ref(helper)));
+  // std::ref(helper) guarantees that a reference to helper, not a copy of
+  // it, will be passed to the callback.
+```
+
+What if the callable takes an argument by reference and we do **not** wrap the
+argument in `std::ref()`? Then `InvokeArgument<N>()` will *make a copy* of the
+argument, and pass a *reference to the copy*, instead of a reference to the
+original value, to the callable. This is especially handy when the argument is a
+temporary value:
+
+```cpp
+  ...
+  MOCK_METHOD(bool, DoThat, (bool (*f)(const double& x, const string& s)),
+              (override));
+  ...
+  using ::testing::_;
+  using ::testing::InvokeArgument;
+  ...
+  MockFoo foo;
+  ...
+  EXPECT_CALL(foo, DoThat(_))
+      .WillOnce(InvokeArgument<0>(5.0, string("Hi")));
+  // Will execute (*f)(5.0, string("Hi")), where f is the function pointer
+  // DoThat() receives. Note that the values 5.0 and string("Hi") are
+  // temporary and dead once the EXPECT_CALL() statement finishes.
Yet
+  // it's fine to perform this action later, since a copy of the values
+  // are kept inside the InvokeArgument action.
+```
+
+### Ignoring an Action's Result
+
+Sometimes you have an action that returns *something*, but you need an action
+that returns `void` (perhaps you want to use it in a mock function that returns
+`void`, or perhaps it needs to be used in `DoAll()` and it's not the last in the
+list). `IgnoreResult()` lets you do that. For example:
+
+```cpp
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::IgnoreResult;
+using ::testing::Return;
+
+int Process(const MyData& data);
+string DoSomething();
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(void, Abc, (const MyData& data), (override));
+  MOCK_METHOD(bool, Xyz, (), (override));
+};
+
+  ...
+  MockFoo foo;
+  EXPECT_CALL(foo, Abc(_))
+      // .WillOnce(Invoke(Process));
+      // The above line won't compile as Process() returns int but Abc() needs
+      // to return void.
+      .WillOnce(IgnoreResult(Process));
+  EXPECT_CALL(foo, Xyz())
+      .WillOnce(DoAll(IgnoreResult(DoSomething),
+                      // Ignores the string DoSomething() returns.
+                      Return(true)));
+```
+
+Note that you **cannot** use `IgnoreResult()` on an action that already returns
+`void`. Doing so will lead to ugly compiler errors.
+
+### Selecting an Action's Arguments {#SelectingArgs}
+
+Say you have a mock function `Foo()` that takes seven arguments, and you have a
+custom action that you want to invoke when `Foo()` is called. Trouble is, the
+custom action only wants three arguments:
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+...
+  MOCK_METHOD(bool, Foo,
+              (bool visible, const string& name, int x, int y,
+               (const map<pair<int, int>, double>&) weight, double min_weight,
+               double max_wight));
+...
+bool IsVisibleInQuadrant1(bool visible, int x, int y) {
+  return visible && x >= 0 && y >= 0;
+}
+...
+  EXPECT_CALL(mock, Foo)
+      .WillOnce(Invoke(IsVisibleInQuadrant1));  // Uh, won't compile.
:-(
+```
+
+To please the compiler God, you need to define an "adaptor" that has the same
+signature as `Foo()` and calls the custom action with the right arguments:
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+...
+bool MyIsVisibleInQuadrant1(bool visible, const string& name, int x, int y,
+                            const map<pair<int, int>, double>& weight,
+                            double min_weight, double max_wight) {
+  return IsVisibleInQuadrant1(visible, x, y);
+}
+...
+  EXPECT_CALL(mock, Foo)
+      .WillOnce(Invoke(MyIsVisibleInQuadrant1));  // Now it works.
+```
+
+But isn't this awkward?
+
+gMock provides a generic *action adaptor*, so you can spend your time minding
+more important business than writing your own adaptors. Here's the syntax:
+
+```cpp
+WithArgs<N1, N2, ..., Nk>(action)
+```
+
+creates an action that passes the arguments of the mock function at the given
+indices (0-based) to the inner `action` and performs it. Using `WithArgs`, our
+original example can be written as:
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::WithArgs;
+...
+  EXPECT_CALL(mock, Foo)
+      .WillOnce(WithArgs<0, 2, 3>(Invoke(IsVisibleInQuadrant1)));  // No need to define your own adaptor.
+```
+
+For better readability, gMock also gives you:
+
+* `WithoutArgs(action)` when the inner `action` takes *no* argument, and
+* `WithArg<N>(action)` (no `s` after `Arg`) when the inner `action` takes
+  *one* argument.
+
+As you may have realized, `InvokeWithoutArgs(...)` is just syntactic sugar for
+`WithoutArgs(Invoke(...))`.
+
+Here are more tips:
+
+* The inner action used in `WithArgs` and friends does not have to be
+  `Invoke()` -- it can be anything.
+* You can repeat an argument in the argument list if necessary, e.g.
+  `WithArgs<2, 3, 3, 5>(...)`.
+* You can change the order of the arguments, e.g. `WithArgs<3, 2, 1>(...)`.
+* The types of the selected arguments do *not* have to match the signature of
+  the inner action exactly.
It works as long as they can be implicitly
+  converted to the corresponding arguments of the inner action. For example,
+  if the 4-th argument of the mock function is an `int` and `my_action` takes
+  a `double`, `WithArg<4>(my_action)` will work.
+
+### Ignoring Arguments in Action Functions
+
+The [selecting-an-action's-arguments](#SelectingArgs) recipe showed us one way
+to make a mock function and an action with incompatible argument lists fit
+together. The downside is that wrapping the action in `WithArgs<...>()` can get
+tedious for people writing the tests.
+
+If you are defining a function (or method, functor, lambda, callback) to be used
+with `Invoke*()`, and you are not interested in some of its arguments, an
+alternative to `WithArgs` is to declare the uninteresting arguments as `Unused`.
+This makes the definition less cluttered and less fragile in case the types of
+the uninteresting arguments change. It could also increase the chance the action
+function can be reused. For example, given
+
+```cpp
+ public:
+  MOCK_METHOD(double, Foo, (const string& label, double x, double y),
+              (override));
+  MOCK_METHOD(double, Bar, (int index, double x, double y), (override));
+```
+
+instead of
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+
+double DistanceToOriginWithLabel(const string& label, double x, double y) {
+  return sqrt(x*x + y*y);
+}
+double DistanceToOriginWithIndex(int index, double x, double y) {
+  return sqrt(x*x + y*y);
+}
+...
+  EXPECT_CALL(mock, Foo("abc", _, _))
+      .WillOnce(Invoke(DistanceToOriginWithLabel));
+  EXPECT_CALL(mock, Bar(5, _, _))
+      .WillOnce(Invoke(DistanceToOriginWithIndex));
+```
+
+you could write
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::Unused;
+
+double DistanceToOrigin(Unused, double x, double y) {
+  return sqrt(x*x + y*y);
+}
+...
+ EXPECT_CALL(mock, Foo("abc", _, _)) + .WillOnce(Invoke(DistanceToOrigin)); + EXPECT_CALL(mock, Bar(5, _, _)) + .WillOnce(Invoke(DistanceToOrigin)); +``` + +### Sharing Actions + +Just like matchers, a gMock action object consists of a pointer to a ref-counted +implementation object. Therefore copying actions is also allowed and very +efficient. When the last action that references the implementation object dies, +the implementation object will be deleted. + +If you have some complex action that you want to use again and again, you may +not have to build it from scratch every time. If the action doesn't have an +internal state (i.e. if it always does the same thing no matter how many times +it has been called), you can assign it to an action variable and use that +variable repeatedly. For example: + +```cpp +using ::testing::Action; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; +... + Action set_flag = DoAll(SetArgPointee<0>(5), + Return(true)); + ... use set_flag in .WillOnce() and .WillRepeatedly() ... +``` + +However, if the action has its own state, you may be surprised if you share the +action object. Suppose you have an action factory `IncrementCounter(init)` which +creates an action that increments and returns a counter whose initial value is +`init`, using two actions created from the same expression and using a shared +action will exhibit different behaviors. Example: + +```cpp + EXPECT_CALL(foo, DoThis()) + .WillRepeatedly(IncrementCounter(0)); + EXPECT_CALL(foo, DoThat()) + .WillRepeatedly(IncrementCounter(0)); + foo.DoThis(); // Returns 1. + foo.DoThis(); // Returns 2. + foo.DoThat(); // Returns 1 - DoThat() uses a different + // counter than DoThis()'s. +``` + +versus + +```cpp +using ::testing::Action; +... + Action increment = IncrementCounter(0); + EXPECT_CALL(foo, DoThis()) + .WillRepeatedly(increment); + EXPECT_CALL(foo, DoThat()) + .WillRepeatedly(increment); + foo.DoThis(); // Returns 1. 
+ foo.DoThis(); // Returns 2. + foo.DoThat(); // Returns 3 - the counter is shared. +``` + +### Testing Asynchronous Behavior + +One oft-encountered problem with gMock is that it can be hard to test +asynchronous behavior. Suppose you had a `EventQueue` class that you wanted to +test, and you created a separate `EventDispatcher` interface so that you could +easily mock it out. However, the implementation of the class fired all the +events on a background thread, which made test timings difficult. You could just +insert `sleep()` statements and hope for the best, but that makes your test +behavior nondeterministic. A better way is to use gMock actions and +`Notification` objects to force your asynchronous test to behave synchronously. + +```cpp +class MockEventDispatcher : public EventDispatcher { + MOCK_METHOD(bool, DispatchEvent, (int32), (override)); +}; + +TEST(EventQueueTest, EnqueueEventTest) { + MockEventDispatcher mock_event_dispatcher; + EventQueue event_queue(&mock_event_dispatcher); + + const int32 kEventId = 321; + absl::Notification done; + EXPECT_CALL(mock_event_dispatcher, DispatchEvent(kEventId)) + .WillOnce([&done] { done.Notify(); }); + + event_queue.EnqueueEvent(kEventId); + done.WaitForNotification(); +} +``` + +In the example above, we set our normal gMock expectations, but then add an +additional action to notify the `Notification` object. Now we can just call +`Notification::WaitForNotification()` in the main thread to wait for the +asynchronous call to finish. After that, our test suite is complete and we can +safely exit. + +{: .callout .note} +Note: this example has a downside: namely, if the expectation is not satisfied, +our test will run forever. It will eventually time-out and fail, but it will +take longer and be slightly harder to debug. To alleviate this problem, you can +use `WaitForNotificationWithTimeout(ms)` instead of `WaitForNotification()`. 
+
+## Misc Recipes on Using gMock
+
+### Mocking Methods That Use Move-Only Types
+
+C++11 introduced *move-only types*. A move-only-typed value can be moved from
+one object to another, but cannot be copied. `std::unique_ptr<int>` is probably
+the most commonly used move-only type.
+
+Mocking a method that takes and/or returns move-only types presents some
+challenges, but nothing insurmountable. This recipe shows you how you can do it.
+Note that the support for move-only method arguments was only introduced to
+gMock in April 2017; in older code, you may find more complex
+[workarounds](#LegacyMoveOnly) for lack of this feature.
+
+Let’s say we are working on a fictional project that lets one post and share
+snippets called “buzzes”. Your code uses these types:
+
+```cpp
+enum class AccessLevel { kInternal, kPublic };
+
+class Buzz {
+ public:
+  explicit Buzz(AccessLevel access) { ... }
+  ...
+};
+
+class Buzzer {
+ public:
+  virtual ~Buzzer() {}
+  virtual std::unique_ptr<Buzz> MakeBuzz(StringPiece text) = 0;
+  virtual bool ShareBuzz(std::unique_ptr<Buzz> buzz, int64_t timestamp) = 0;
+  ...
+};
+```
+
+A `Buzz` object represents a snippet being posted. A class that implements the
+`Buzzer` interface is capable of creating and sharing `Buzz`es. Methods in
+`Buzzer` may return a `unique_ptr<Buzz>` or take a `unique_ptr<Buzz>`. Now we
+need to mock `Buzzer` in our tests.
+
+To mock a method that accepts or returns move-only types, you just use the
+familiar `MOCK_METHOD` syntax as usual:
+
+```cpp
+class MockBuzzer : public Buzzer {
+ public:
+  MOCK_METHOD(std::unique_ptr<Buzz>, MakeBuzz, (StringPiece text), (override));
+  MOCK_METHOD(bool, ShareBuzz, (std::unique_ptr<Buzz> buzz, int64_t timestamp),
+              (override));
+};
+```
+
+Now that we have the mock class defined, we can use it in tests.
+ In the
+following code examples, we assume that we have defined a `MockBuzzer` object
+named `mock_buzzer_`:
+
+```cpp
+  MockBuzzer mock_buzzer_;
+```
+
+First let’s see how we can set expectations on the `MakeBuzz()` method, which
+returns a `unique_ptr<Buzz>`.
+
+As usual, if you set an expectation without an action (i.e. the `.WillOnce()` or
+`.WillRepeatedly()` clause), when that expectation fires, the default action for
+that method will be taken. Since `unique_ptr<>` has a default constructor that
+returns a null `unique_ptr`, that’s what you’ll get if you don’t specify an
+action:
+
+```cpp
+using ::testing::IsNull;
+...
+  // Use the default action.
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello"));
+
+  // Triggers the previous EXPECT_CALL.
+  EXPECT_THAT(mock_buzzer_.MakeBuzz("hello"), IsNull());
+```
+
+If you are not happy with the default action, you can tweak it as usual; see
+[Setting Default Actions](#OnCall).
+
+If you just need to return a move-only value, you can use it in combination with
+`WillOnce`. For example:
+
+```cpp
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello"))
+      .WillOnce(Return(std::make_unique<Buzz>(AccessLevel::kInternal)));
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("hello"));
+```
+
+Quiz time! What do you think will happen if a `Return` action is performed more
+than once (e.g. you write `... .WillRepeatedly(Return(std::move(...)));`)? Come
+think of it, after the first time the action runs, the source value will be
+consumed (since it’s a move-only value), so the next time around, there’s no
+value to move from -- you’ll get a run-time error that `Return(std::move(...))`
+can only be run once.
+
+If you need your mock method to do more than just moving a pre-defined value,
+remember that you can always use a lambda or a callable object, which can do
+pretty much anything you want:
+
+```cpp
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("x"))
+      .WillRepeatedly([](StringPiece text) {
+        return std::make_unique<Buzz>(AccessLevel::kInternal);
+      });
+
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("x"));
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("x"));
+```
+
+Every time this `EXPECT_CALL` fires, a new `unique_ptr<Buzz>` will be created
+and returned. You cannot do this with `Return(std::make_unique<...>(...))`.
+
+That covers returning move-only values; but how do we work with methods
+accepting move-only arguments? The answer is that they work normally, although
+some actions will not compile when any of method's arguments are move-only. You
+can always use `Return`, or a [lambda or functor](#FunctionsAsActions):
+
+```cpp
+  using ::testing::Unused;
+
+  EXPECT_CALL(mock_buzzer_, ShareBuzz(NotNull(), _)).WillOnce(Return(true));
+  EXPECT_TRUE(mock_buzzer_.ShareBuzz(std::make_unique<Buzz>(AccessLevel::kInternal),
+                                     0));
+
+  EXPECT_CALL(mock_buzzer_, ShareBuzz(_, _)).WillOnce(
+      [](std::unique_ptr<Buzz> buzz, Unused) { return buzz != nullptr; });
+  EXPECT_FALSE(mock_buzzer_.ShareBuzz(nullptr, 0));
+```
+
+Many built-in actions (`WithArgs`, `WithoutArgs`,`DeleteArg`, `SaveArg`, ...)
+could in principle support move-only arguments, but the support for this is not
+implemented yet. If this is blocking you, please file a bug.
+
+A few actions (e.g. `DoAll`) copy their arguments internally, so they can never
+work with non-copyable objects; you'll have to use functors instead.
+
+#### Legacy workarounds for move-only types {#LegacyMoveOnly}
+
+Support for move-only function arguments was only introduced to gMock in April
+of 2017.
In older code, you may encounter the following workaround for the lack +of this feature (it is no longer necessary - we're including it just for +reference): + +```cpp +class MockBuzzer : public Buzzer { + public: + MOCK_METHOD(bool, DoShareBuzz, (Buzz* buzz, Time timestamp)); + bool ShareBuzz(std::unique_ptr buzz, Time timestamp) override { + return DoShareBuzz(buzz.get(), timestamp); + } +}; +``` + +The trick is to delegate the `ShareBuzz()` method to a mock method (let’s call +it `DoShareBuzz()`) that does not take move-only parameters. Then, instead of +setting expectations on `ShareBuzz()`, you set them on the `DoShareBuzz()` mock +method: + +```cpp + MockBuzzer mock_buzzer_; + EXPECT_CALL(mock_buzzer_, DoShareBuzz(NotNull(), _)); + + // When one calls ShareBuzz() on the MockBuzzer like this, the call is + // forwarded to DoShareBuzz(), which is mocked. Therefore this statement + // will trigger the above EXPECT_CALL. + mock_buzzer_.ShareBuzz(std::make_unique(AccessLevel::kInternal), 0); +``` + +### Making the Compilation Faster + +Believe it or not, the *vast majority* of the time spent on compiling a mock +class is in generating its constructor and destructor, as they perform +non-trivial tasks (e.g. verification of the expectations). What's more, mock +methods with different signatures have different types and thus their +constructors/destructors need to be generated by the compiler separately. As a +result, if you mock many different types of methods, compiling your mock class +can get really slow. + +If you are experiencing slow compilation, you can move the definition of your +mock class' constructor and destructor out of the class body and into a `.cc` +file. This way, even if you `#include` your mock class in N files, the compiler +only needs to generate its constructor and destructor once, resulting in a much +faster compilation. + +Let's illustrate the idea using an example. 
Here's the definition of a mock +class before applying this recipe: + +```cpp +// File mock_foo.h. +... +class MockFoo : public Foo { + public: + // Since we don't declare the constructor or the destructor, + // the compiler will generate them in every translation unit + // where this mock class is used. + + MOCK_METHOD(int, DoThis, (), (override)); + MOCK_METHOD(bool, DoThat, (const char* str), (override)); + ... more mock methods ... +}; +``` + +After the change, it would look like: + +```cpp +// File mock_foo.h. +... +class MockFoo : public Foo { + public: + // The constructor and destructor are declared, but not defined, here. + MockFoo(); + virtual ~MockFoo(); + + MOCK_METHOD(int, DoThis, (), (override)); + MOCK_METHOD(bool, DoThat, (const char* str), (override)); + ... more mock methods ... +}; +``` + +and + +```cpp +// File mock_foo.cc. +#include "path/to/mock_foo.h" + +// The definitions may appear trivial, but the functions actually do a +// lot of things through the constructors/destructors of the member +// variables used to implement the mock methods. +MockFoo::MockFoo() {} +MockFoo::~MockFoo() {} +``` + +### Forcing a Verification + +When it's being destroyed, your friendly mock object will automatically verify +that all expectations on it have been satisfied, and will generate googletest +failures if not. This is convenient as it leaves you with one less thing to +worry about. That is, unless you are not sure if your mock object will be +destroyed. + +How could it be that your mock object won't eventually be destroyed? Well, it +might be created on the heap and owned by the code you are testing. Suppose +there's a bug in that code and it doesn't delete the mock object properly - you +could end up with a passing test when there's actually a bug. + +Using a heap checker is a good idea and can alleviate the concern, but its +implementation is not 100% reliable. 
So, sometimes you do want to *force* gMock +to verify a mock object before it is (hopefully) destructed. You can do this +with `Mock::VerifyAndClearExpectations(&mock_object)`: + +```cpp +TEST(MyServerTest, ProcessesRequest) { + using ::testing::Mock; + + MockFoo* const foo = new MockFoo; + EXPECT_CALL(*foo, ...)...; + // ... other expectations ... + + // server now owns foo. + MyServer server(foo); + server.ProcessRequest(...); + + // In case that server's destructor will forget to delete foo, + // this will verify the expectations anyway. + Mock::VerifyAndClearExpectations(foo); +} // server is destroyed when it goes out of scope here. +``` + +{: .callout .tip} +**Tip:** The `Mock::VerifyAndClearExpectations()` function returns a `bool` to +indicate whether the verification was successful (`true` for yes), so you can +wrap that function call inside a `ASSERT_TRUE()` if there is no point going +further when the verification has failed. + +Do not set new expectations after verifying and clearing a mock after its use. +Setting expectations after code that exercises the mock has undefined behavior. +See [Using Mocks in Tests](gmock_for_dummies.md#using-mocks-in-tests) for more +information. + +### Using Checkpoints {#UsingCheckPoints} + +Sometimes you might want to test a mock object's behavior in phases whose sizes +are each manageable, or you might want to set more detailed expectations about +which API calls invoke which mock functions. + +A technique you can use is to put the expectations in a sequence and insert +calls to a dummy "checkpoint" function at specific places. Then you can verify +that the mock function calls do happen at the right time. 
For example, if you
+are exercising the code:
+
+```cpp
+  Foo(1);
+  Foo(2);
+  Foo(3);
+```
+
+and want to verify that `Foo(1)` and `Foo(3)` both invoke `mock.Bar("a")`, but
+`Foo(2)` doesn't invoke anything, you can write:
+
+```cpp
+using ::testing::MockFunction;
+
+TEST(FooTest, InvokesBarCorrectly) {
+  MyMock mock;
+  // Class MockFunction<F> has exactly one mock method. It is named
+  // Call() and has type F.
+  MockFunction<void(string check_point_name)> check;
+  {
+    InSequence s;
+
+    EXPECT_CALL(mock, Bar("a"));
+    EXPECT_CALL(check, Call("1"));
+    EXPECT_CALL(check, Call("2"));
+    EXPECT_CALL(mock, Bar("a"));
+  }
+  Foo(1);
+  check.Call("1");
+  Foo(2);
+  check.Call("2");
+  Foo(3);
+}
+```
+
+The expectation spec says that the first `Bar("a")` call must happen before
+checkpoint "1", the second `Bar("a")` call must happen after checkpoint "2", and
+nothing should happen between the two checkpoints. The explicit checkpoints make
+it clear which `Bar("a")` is called by which call to `Foo()`.
+
+### Mocking Destructors
+
+Sometimes you want to make sure a mock object is destructed at the right time,
+e.g. after `bar->A()` is called but before `bar->B()` is called. We already know
+that you can specify constraints on the [order](#OrderedCalls) of mock function
+calls, so all we need to do is to mock the destructor of the mock function.
+
+This sounds simple, except for one problem: a destructor is a special function
+with special syntax and special semantics, and the `MOCK_METHOD` macro doesn't
+work for it:
+
+```cpp
+MOCK_METHOD(void, ~MockFoo, ()); // Won't compile!
+```
+
+The good news is that you can use a simple pattern to achieve the same effect.
+First, add a mock function `Die()` to your mock class and call it in the
+destructor, like this:
+
+```cpp
+class MockFoo : public Foo {
+  ...
+  // Add the following two lines to the mock class.
+  MOCK_METHOD(void, Die, ());
+  ~MockFoo() override { Die(); }
+};
+```
+
+(If the name `Die()` clashes with an existing symbol, choose another name.) 
Now, +we have translated the problem of testing when a `MockFoo` object dies to +testing when its `Die()` method is called: + +```cpp + MockFoo* foo = new MockFoo; + MockBar* bar = new MockBar; + ... + { + InSequence s; + + // Expects *foo to die after bar->A() and before bar->B(). + EXPECT_CALL(*bar, A()); + EXPECT_CALL(*foo, Die()); + EXPECT_CALL(*bar, B()); + } +``` + +And that's that. + +### Using gMock and Threads {#UsingThreads} + +In a **unit** test, it's best if you could isolate and test a piece of code in a +single-threaded context. That avoids race conditions and dead locks, and makes +debugging your test much easier. + +Yet most programs are multi-threaded, and sometimes to test something we need to +pound on it from more than one thread. gMock works for this purpose too. + +Remember the steps for using a mock: + +1. Create a mock object `foo`. +2. Set its default actions and expectations using `ON_CALL()` and + `EXPECT_CALL()`. +3. The code under test calls methods of `foo`. +4. Optionally, verify and reset the mock. +5. Destroy the mock yourself, or let the code under test destroy it. The + destructor will automatically verify it. + +If you follow the following simple rules, your mocks and threads can live +happily together: + +* Execute your *test code* (as opposed to the code being tested) in *one* + thread. This makes your test easy to follow. +* Obviously, you can do step #1 without locking. +* When doing step #2 and #5, make sure no other thread is accessing `foo`. + Obvious too, huh? +* #3 and #4 can be done either in one thread or in multiple threads - anyway + you want. gMock takes care of the locking, so you don't have to do any - + unless required by your test logic. + +If you violate the rules (for example, if you set expectations on a mock while +another thread is calling its methods), you get undefined behavior. That's not +fun, so don't do it. 
+ +gMock guarantees that the action for a mock function is done in the same thread +that called the mock function. For example, in + +```cpp + EXPECT_CALL(mock, Foo(1)) + .WillOnce(action1); + EXPECT_CALL(mock, Foo(2)) + .WillOnce(action2); +``` + +if `Foo(1)` is called in thread 1 and `Foo(2)` is called in thread 2, gMock will +execute `action1` in thread 1 and `action2` in thread 2. + +gMock does *not* impose a sequence on actions performed in different threads +(doing so may create deadlocks as the actions may need to cooperate). This means +that the execution of `action1` and `action2` in the above example *may* +interleave. If this is a problem, you should add proper synchronization logic to +`action1` and `action2` to make the test thread-safe. + +Also, remember that `DefaultValue` is a global resource that potentially +affects *all* living mock objects in your program. Naturally, you won't want to +mess with it from multiple threads or when there still are mocks in action. + +### Controlling How Much Information gMock Prints + +When gMock sees something that has the potential of being an error (e.g. a mock +function with no expectation is called, a.k.a. an uninteresting call, which is +allowed but perhaps you forgot to explicitly ban the call), it prints some +warning messages, including the arguments of the function, the return value, and +the stack trace. Hopefully this will remind you to take a look and see if there +is indeed a problem. + +Sometimes you are confident that your tests are correct and may not appreciate +such friendly messages. Some other times, you are debugging your tests or +learning about the behavior of the code you are testing, and wish you could +observe every mock call that happens (including argument values, the return +value, and the stack trace). Clearly, one size doesn't fit all. 
+ +You can control how much gMock tells you using the `--gmock_verbose=LEVEL` +command-line flag, where `LEVEL` is a string with three possible values: + +* `info`: gMock will print all informational messages, warnings, and errors + (most verbose). At this setting, gMock will also log any calls to the + `ON_CALL/EXPECT_CALL` macros. It will include a stack trace in + "uninteresting call" warnings. +* `warning`: gMock will print both warnings and errors (less verbose); it will + omit the stack traces in "uninteresting call" warnings. This is the default. +* `error`: gMock will print errors only (least verbose). + +Alternatively, you can adjust the value of that flag from within your tests like +so: + +```cpp + ::testing::FLAGS_gmock_verbose = "error"; +``` + +If you find gMock printing too many stack frames with its informational or +warning messages, remember that you can control their amount with the +`--gtest_stack_trace_depth=max_depth` flag. + +Now, judiciously use the right flag to enable gMock serve you better! + +### Gaining Super Vision into Mock Calls + +You have a test using gMock. It fails: gMock tells you some expectations aren't +satisfied. However, you aren't sure why: Is there a typo somewhere in the +matchers? Did you mess up the order of the `EXPECT_CALL`s? Or is the code under +test doing something wrong? How can you find out the cause? + +Won't it be nice if you have X-ray vision and can actually see the trace of all +`EXPECT_CALL`s and mock method calls as they are made? For each call, would you +like to see its actual argument values and which `EXPECT_CALL` gMock thinks it +matches? If you still need some help to figure out who made these calls, how +about being able to see the complete stack trace at each mock call? + +You can unlock this power by running your test with the `--gmock_verbose=info` +flag. 
For example, given the test program: + +```cpp +#include + +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::Return; + +class MockFoo { + public: + MOCK_METHOD(void, F, (const string& x, const string& y)); +}; + +TEST(Foo, Bar) { + MockFoo mock; + EXPECT_CALL(mock, F(_, _)).WillRepeatedly(Return()); + EXPECT_CALL(mock, F("a", "b")); + EXPECT_CALL(mock, F("c", HasSubstr("d"))); + + mock.F("a", "good"); + mock.F("a", "b"); +} +``` + +if you run it with `--gmock_verbose=info`, you will see this output: + +```shell +[ RUN ] Foo.Bar + +foo_test.cc:14: EXPECT_CALL(mock, F(_, _)) invoked +Stack trace: ... + +foo_test.cc:15: EXPECT_CALL(mock, F("a", "b")) invoked +Stack trace: ... + +foo_test.cc:16: EXPECT_CALL(mock, F("c", HasSubstr("d"))) invoked +Stack trace: ... + +foo_test.cc:14: Mock function call matches EXPECT_CALL(mock, F(_, _))... + Function call: F(@0x7fff7c8dad40"a",@0x7fff7c8dad10"good") +Stack trace: ... + +foo_test.cc:15: Mock function call matches EXPECT_CALL(mock, F("a", "b"))... + Function call: F(@0x7fff7c8dada0"a",@0x7fff7c8dad70"b") +Stack trace: ... + +foo_test.cc:16: Failure +Actual function call count doesn't match EXPECT_CALL(mock, F("c", HasSubstr("d")))... + Expected: to be called once + Actual: never called - unsatisfied and active +[ FAILED ] Foo.Bar +``` + +Suppose the bug is that the `"c"` in the third `EXPECT_CALL` is a typo and +should actually be `"a"`. With the above message, you should see that the actual +`F("a", "good")` call is matched by the first `EXPECT_CALL`, not the third as +you thought. From that it should be obvious that the third `EXPECT_CALL` is +written wrong. Case solved. + +If you are interested in the mock call trace but not the stack traces, you can +combine `--gmock_verbose=info` with `--gtest_stack_trace_depth=0` on the test +command line. 
+ +### Running Tests in Emacs + +If you build and run your tests in Emacs using the `M-x google-compile` command +(as many googletest users do), the source file locations of gMock and googletest +errors will be highlighted. Just press `` on one of them and you'll be +taken to the offending line. Or, you can just type `C-x`` to jump to the next +error. + +To make it even easier, you can add the following lines to your `~/.emacs` file: + +```text +(global-set-key "\M-m" 'google-compile) ; m is for make +(global-set-key [M-down] 'next-error) +(global-set-key [M-up] '(lambda () (interactive) (next-error -1))) +``` + +Then you can type `M-m` to start a build (if you want to run the test as well, +just make sure `foo_test.run` or `runtests` is in the build command you supply +after typing `M-m`), or `M-up`/`M-down` to move back and forth between errors. + +## Extending gMock + +### Writing New Matchers Quickly {#NewMatchers} + +{: .callout .warning} +WARNING: gMock does not guarantee when or how many times a matcher will be +invoked. Therefore, all matchers must be functionally pure. See +[this section](#PureMatchers) for more details. + +The `MATCHER*` family of macros can be used to define custom matchers easily. +The syntax: + +```cpp +MATCHER(name, description_string_expression) { statements; } +``` + +will define a matcher with the given name that executes the statements, which +must return a `bool` to indicate if the match succeeds. Inside the statements, +you can refer to the value being matched by `arg`, and refer to its type by +`arg_type`. + +The *description string* is a `string`-typed expression that documents what the +matcher does, and is used to generate the failure message when the match fails. +It can (and should) reference the special `bool` variable `negation`, and should +evaluate to the description of the matcher when `negation` is `false`, or that +of the matcher's negation when `negation` is `true`. 
+ +For convenience, we allow the description string to be empty (`""`), in which +case gMock will use the sequence of words in the matcher name as the +description. + +For example: + +```cpp +MATCHER(IsDivisibleBy7, "") { return (arg % 7) == 0; } +``` + +allows you to write + +```cpp + // Expects mock_foo.Bar(n) to be called where n is divisible by 7. + EXPECT_CALL(mock_foo, Bar(IsDivisibleBy7())); +``` + +or, + +```cpp + using ::testing::Not; + ... + // Verifies that a value is divisible by 7 and the other is not. + EXPECT_THAT(some_expression, IsDivisibleBy7()); + EXPECT_THAT(some_other_expression, Not(IsDivisibleBy7())); +``` + +If the above assertions fail, they will print something like: + +```shell + Value of: some_expression + Expected: is divisible by 7 + Actual: 27 + ... + Value of: some_other_expression + Expected: not (is divisible by 7) + Actual: 21 +``` + +where the descriptions `"is divisible by 7"` and `"not (is divisible by 7)"` are +automatically calculated from the matcher name `IsDivisibleBy7`. + +As you may have noticed, the auto-generated descriptions (especially those for +the negation) may not be so great. You can always override them with a `string` +expression of your own: + +```cpp +MATCHER(IsDivisibleBy7, + absl::StrCat(negation ? "isn't" : "is", " divisible by 7")) { + return (arg % 7) == 0; +} +``` + +Optionally, you can stream additional information to a hidden argument named +`result_listener` to explain the match result. 
For example, a better definition +of `IsDivisibleBy7` is: + +```cpp +MATCHER(IsDivisibleBy7, "") { + if ((arg % 7) == 0) + return true; + + *result_listener << "the remainder is " << (arg % 7); + return false; +} +``` + +With this definition, the above assertion will give a better message: + +```shell + Value of: some_expression + Expected: is divisible by 7 + Actual: 27 (the remainder is 6) +``` + +You should let `MatchAndExplain()` print *any additional information* that can +help a user understand the match result. Note that it should explain why the +match succeeds in case of a success (unless it's obvious) - this is useful when +the matcher is used inside `Not()`. There is no need to print the argument value +itself, as gMock already prints it for you. + +{: .callout .note} +NOTE: The type of the value being matched (`arg_type`) is determined by the +context in which you use the matcher and is supplied to you by the compiler, so +you don't need to worry about declaring it (nor can you). This allows the +matcher to be polymorphic. For example, `IsDivisibleBy7()` can be used to match +any type where the value of `(arg % 7) == 0` can be implicitly converted to a +`bool`. In the `Bar(IsDivisibleBy7())` example above, if method `Bar()` takes an +`int`, `arg_type` will be `int`; if it takes an `unsigned long`, `arg_type` will +be `unsigned long`; and so on. + +### Writing New Parameterized Matchers Quickly + +Sometimes you'll want to define a matcher that has parameters. For that you can +use the macro: + +```cpp +MATCHER_P(name, param_name, description_string) { statements; } +``` + +where the description string can be either `""` or a `string` expression that +references `negation` and `param_name`. 
+ +For example: + +```cpp +MATCHER_P(HasAbsoluteValue, value, "") { return abs(arg) == value; } +``` + +will allow you to write: + +```cpp + EXPECT_THAT(Blah("a"), HasAbsoluteValue(n)); +``` + +which may lead to this message (assuming `n` is 10): + +```shell + Value of: Blah("a") + Expected: has absolute value 10 + Actual: -9 +``` + +Note that both the matcher description and its parameter are printed, making the +message human-friendly. + +In the matcher definition body, you can write `foo_type` to reference the type +of a parameter named `foo`. For example, in the body of +`MATCHER_P(HasAbsoluteValue, value)` above, you can write `value_type` to refer +to the type of `value`. + +gMock also provides `MATCHER_P2`, `MATCHER_P3`, ..., up to `MATCHER_P10` to +support multi-parameter matchers: + +```cpp +MATCHER_Pk(name, param_1, ..., param_k, description_string) { statements; } +``` + +Please note that the custom description string is for a particular *instance* of +the matcher, where the parameters have been bound to actual values. Therefore +usually you'll want the parameter values to be part of the description. gMock +lets you do that by referencing the matcher parameters in the description string +expression. + +For example, + +```cpp +using ::testing::PrintToString; +MATCHER_P2(InClosedRange, low, hi, + absl::StrFormat("%s in range [%s, %s]", negation ? "isn't" : "is", + PrintToString(low), PrintToString(hi))) { + return low <= arg && arg <= hi; +} +... +EXPECT_THAT(3, InClosedRange(4, 6)); +``` + +would generate a failure that contains the message: + +```shell + Expected: is in range [4, 6] +``` + +If you specify `""` as the description, the failure message will contain the +sequence of words in the matcher name followed by the parameter values printed +as a tuple. For example, + +```cpp + MATCHER_P2(InClosedRange, low, hi, "") { ... } + ... 
+ EXPECT_THAT(3, InClosedRange(4, 6)); +``` + +would generate a failure that contains the text: + +```shell + Expected: in closed range (4, 6) +``` + +For the purpose of typing, you can view + +```cpp +MATCHER_Pk(Foo, p1, ..., pk, description_string) { ... } +``` + +as shorthand for + +```cpp +template +FooMatcherPk +Foo(p1_type p1, ..., pk_type pk) { ... } +``` + +When you write `Foo(v1, ..., vk)`, the compiler infers the types of the +parameters `v1`, ..., and `vk` for you. If you are not happy with the result of +the type inference, you can specify the types by explicitly instantiating the +template, as in `Foo(5, false)`. As said earlier, you don't get to +(or need to) specify `arg_type` as that's determined by the context in which the +matcher is used. + +You can assign the result of expression `Foo(p1, ..., pk)` to a variable of type +`FooMatcherPk`. This can be useful when composing +matchers. Matchers that don't have a parameter or have only one parameter have +special types: you can assign `Foo()` to a `FooMatcher`-typed variable, and +assign `Foo(p)` to a `FooMatcherP`-typed variable. + +While you can instantiate a matcher template with reference types, passing the +parameters by pointer usually makes your code more readable. If, however, you +still want to pass a parameter by reference, be aware that in the failure +message generated by the matcher you will see the value of the referenced object +but not its address. + +You can overload matchers with different numbers of parameters: + +```cpp +MATCHER_P(Blah, a, description_string_1) { ... } +MATCHER_P2(Blah, a, b, description_string_2) { ... } +``` + +While it's tempting to always use the `MATCHER*` macros when defining a new +matcher, you should also consider implementing the matcher interface directly +instead (see the recipes that follow), especially if you need to use the matcher +a lot. 
While these approaches require more work, they give you more control on
+the types of the value being matched and the matcher parameters, which in
+general leads to better compiler error messages that pay off in the long run.
+They also allow overloading matchers based on parameter types (as opposed to
+just based on the number of parameters).
+
+### Writing New Monomorphic Matchers
+
+A matcher of argument type `T` implements the matcher interface for `T` and does
+two things: it tests whether a value of type `T` matches the matcher, and can
+describe what kind of values it matches. The latter ability is used for
+generating readable error messages when expectations are violated.
+
+A matcher of `T` must declare a typedef like:
+
+```cpp
+using is_gtest_matcher = void;
+```
+
+and supports the following operations:
+
+```cpp
+// Match a value and optionally explain into an ostream.
+bool matched = matcher.MatchAndExplain(value, maybe_os);
+// where `value` is of type `T` and
+// `maybe_os` is of type `std::ostream*`, where it can be null if the caller
+// is not interested in the textual explanation.
+
+matcher.DescribeTo(os);
+matcher.DescribeNegationTo(os);
+// where `os` is of type `std::ostream*`.
+```
+
+If you need a custom matcher but `Truly()` is not a good option (for example,
+you may not be happy with the way `Truly(predicate)` describes itself, or you
+may want your matcher to be polymorphic as `Eq(value)` is), you can define a
+matcher to do whatever you want in two steps: first implement the matcher
+interface, and then define a factory function to create a matcher instance. The
+second step is not strictly needed but it makes the syntax of using the matcher
+nicer. 
+ +For example, you can define a matcher to test whether an `int` is divisible by 7 +and then use it like this: + +```cpp +using ::testing::Matcher; + +class DivisibleBy7Matcher { + public: + using is_gtest_matcher = void; + + bool MatchAndExplain(int n, std::ostream*) const { + return (n % 7) == 0; + } + + void DescribeTo(std::ostream* os) const { + *os << "is divisible by 7"; + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "is not divisible by 7"; + } +}; + +Matcher DivisibleBy7() { + return DivisibleBy7Matcher(); +} + +... + EXPECT_CALL(foo, Bar(DivisibleBy7())); +``` + +You may improve the matcher message by streaming additional information to the +`os` argument in `MatchAndExplain()`: + +```cpp +class DivisibleBy7Matcher { + public: + bool MatchAndExplain(int n, std::ostream* os) const { + const int remainder = n % 7; + if (remainder != 0 && os != nullptr) { + *os << "the remainder is " << remainder; + } + return remainder == 0; + } + ... +}; +``` + +Then, `EXPECT_THAT(x, DivisibleBy7());` may generate a message like this: + +```shell +Value of: x +Expected: is divisible by 7 + Actual: 23 (the remainder is 2) +``` + +{: .callout .tip} +Tip: for convenience, `MatchAndExplain()` can take a `MatchResultListener*` +instead of `std::ostream*`. + +### Writing New Polymorphic Matchers + +Expanding what we learned above to *polymorphic* matchers is now just as simple +as adding templates in the right place. + +```cpp + +class NotNullMatcher { + public: + using is_gtest_matcher = void; + + // To implement a polymorphic matcher, we just need to make MatchAndExplain a + // template on its first argument. + + // In this example, we want to use NotNull() with any pointer, so + // MatchAndExplain() accepts a pointer of any type as its first argument. + // In general, you can define MatchAndExplain() as an ordinary method or + // a method template, or even overload it. 
+ template + bool MatchAndExplain(T* p, std::ostream*) const { + return p != nullptr; + } + + // Describes the property of a value matching this matcher. + void DescribeTo(std::ostream* os) const { *os << "is not NULL"; } + + // Describes the property of a value NOT matching this matcher. + void DescribeNegationTo(std::ostream* os) const { *os << "is NULL"; } +}; + +NotNullMatcher NotNull() { + return NotNullMatcher(); +} + +... + + EXPECT_CALL(foo, Bar(NotNull())); // The argument must be a non-NULL pointer. +``` + +### Legacy Matcher Implementation + +Defining matchers used to be somewhat more complicated, in which it required +several supporting classes and virtual functions. To implement a matcher for +type `T` using the legacy API you have to derive from `MatcherInterface` and +call `MakeMatcher` to construct the object. + +The interface looks like this: + +```cpp +class MatchResultListener { + public: + ... + // Streams x to the underlying ostream; does nothing if the ostream + // is NULL. + template + MatchResultListener& operator<<(const T& x); + + // Returns the underlying ostream. + std::ostream* stream(); +}; + +template +class MatcherInterface { + public: + virtual ~MatcherInterface(); + + // Returns true if and only if the matcher matches x; also explains the match + // result to 'listener'. + virtual bool MatchAndExplain(T x, MatchResultListener* listener) const = 0; + + // Describes this matcher to an ostream. + virtual void DescribeTo(std::ostream* os) const = 0; + + // Describes the negation of this matcher to an ostream. + virtual void DescribeNegationTo(std::ostream* os) const; +}; +``` + +Fortunately, most of the time you can define a polymorphic matcher easily with +the help of `MakePolymorphicMatcher()`. 
Here's how you can define `NotNull()` as +an example: + +```cpp +using ::testing::MakePolymorphicMatcher; +using ::testing::MatchResultListener; +using ::testing::PolymorphicMatcher; + +class NotNullMatcher { + public: + // To implement a polymorphic matcher, first define a COPYABLE class + // that has three members MatchAndExplain(), DescribeTo(), and + // DescribeNegationTo(), like the following. + + // In this example, we want to use NotNull() with any pointer, so + // MatchAndExplain() accepts a pointer of any type as its first argument. + // In general, you can define MatchAndExplain() as an ordinary method or + // a method template, or even overload it. + template + bool MatchAndExplain(T* p, + MatchResultListener* /* listener */) const { + return p != NULL; + } + + // Describes the property of a value matching this matcher. + void DescribeTo(std::ostream* os) const { *os << "is not NULL"; } + + // Describes the property of a value NOT matching this matcher. + void DescribeNegationTo(std::ostream* os) const { *os << "is NULL"; } +}; + +// To construct a polymorphic matcher, pass an instance of the class +// to MakePolymorphicMatcher(). Note the return type. +PolymorphicMatcher NotNull() { + return MakePolymorphicMatcher(NotNullMatcher()); +} + +... + + EXPECT_CALL(foo, Bar(NotNull())); // The argument must be a non-NULL pointer. +``` + +{: .callout .note} +**Note:** Your polymorphic matcher class does **not** need to inherit from +`MatcherInterface` or any other class, and its methods do **not** need to be +virtual. + +Like in a monomorphic matcher, you may explain the match result by streaming +additional information to the `listener` argument in `MatchAndExplain()`. + +### Writing New Cardinalities + +A cardinality is used in `Times()` to tell gMock how many times you expect a +call to occur. It doesn't have to be exact. For example, you can say +`AtLeast(5)` or `Between(2, 4)`. 
+ +If the [built-in set](gmock_cheat_sheet.md#CardinalityList) of cardinalities +doesn't suit you, you are free to define your own by implementing the following +interface (in namespace `testing`): + +```cpp +class CardinalityInterface { + public: + virtual ~CardinalityInterface(); + + // Returns true if and only if call_count calls will satisfy this cardinality. + virtual bool IsSatisfiedByCallCount(int call_count) const = 0; + + // Returns true if and only if call_count calls will saturate this + // cardinality. + virtual bool IsSaturatedByCallCount(int call_count) const = 0; + + // Describes self to an ostream. + virtual void DescribeTo(std::ostream* os) const = 0; +}; +``` + +For example, to specify that a call must occur even number of times, you can +write + +```cpp +using ::testing::Cardinality; +using ::testing::CardinalityInterface; +using ::testing::MakeCardinality; + +class EvenNumberCardinality : public CardinalityInterface { + public: + bool IsSatisfiedByCallCount(int call_count) const override { + return (call_count % 2) == 0; + } + + bool IsSaturatedByCallCount(int call_count) const override { + return false; + } + + void DescribeTo(std::ostream* os) const { + *os << "called even number of times"; + } +}; + +Cardinality EvenNumber() { + return MakeCardinality(new EvenNumberCardinality); +} + +... + EXPECT_CALL(foo, Bar(3)) + .Times(EvenNumber()); +``` + +### Writing New Actions {#QuickNewActions} + +If the built-in actions don't work for you, you can easily define your own one. +All you need is a call operator with a signature compatible with the mocked +function. 
So you can use a lambda: + +```cpp +MockFunction<int(int)> mock; +EXPECT_CALL(mock, Call).WillOnce([](const int input) { return input * 7; }); +EXPECT_EQ(mock.AsStdFunction()(2), 14); +``` + +Or a struct with a call operator (even a templated one): + +```cpp +struct MultiplyBy { + template <typename T> + T operator()(T arg) { return arg * multiplier; } + + int multiplier; +}; + +// Then use: +// EXPECT_CALL(...).WillOnce(MultiplyBy{7}); +``` + +It's also fine for the callable to take no arguments, ignoring the arguments +supplied to the mock function: + +```cpp +MockFunction<int(int)> mock; +EXPECT_CALL(mock, Call).WillOnce([] { return 17; }); +EXPECT_EQ(mock.AsStdFunction()(0), 17); +``` + +When used with `WillOnce`, the callable can assume it will be called at most +once and is allowed to be a move-only type: + +```cpp +// An action that contains move-only types and has an &&-qualified operator, +// demanding in the type system that it be called at most once. This can be +// used with WillOnce, but the compiler will reject it if handed to +// WillRepeatedly. +struct MoveOnlyAction { + std::unique_ptr<int> move_only_state; + std::unique_ptr<int> operator()() && { return std::move(move_only_state); } +}; + +MockFunction<std::unique_ptr<int>()> mock; +EXPECT_CALL(mock, Call).WillOnce(MoveOnlyAction{std::make_unique<int>(17)}); +EXPECT_THAT(mock.AsStdFunction()(), Pointee(Eq(17))); +``` + +More generally, to use with a mock function whose signature is `R(Args...)` the +object can be anything convertible to `OnceAction<R(Args...)>` or +`Action<R(Args...)>`. The difference between the two is that `OnceAction` has +weaker requirements (`Action` requires a copy-constructible input that can be +called repeatedly whereas `OnceAction` requires only move-constructible and +supports `&&`-qualified call operators), but can be used only with `WillOnce`. +`OnceAction` is typically relevant only when supporting move-only types or +actions that want a type-system guarantee that they will be called at most once. 
+ +Typically the `OnceAction` and `Action` templates need not be referenced +directly in your actions: a struct or class with a call operator is sufficient, +as in the examples above. But fancier polymorphic actions that need to know the +specific return type of the mock function can define templated conversion +operators to make that possible. See `gmock-actions.h` for examples. + +#### Legacy macro-based Actions + +Before C++11, the functor-based actions were not supported; the old way of +writing actions was through a set of `ACTION*` macros. We suggest to avoid them +in new code; they hide a lot of logic behind the macro, potentially leading to +harder-to-understand compiler errors. Nevertheless, we cover them here for +completeness. + +By writing + +```cpp +ACTION(name) { statements; } +``` + +in a namespace scope (i.e. not inside a class or function), you will define an +action with the given name that executes the statements. The value returned by +`statements` will be used as the return value of the action. Inside the +statements, you can refer to the K-th (0-based) argument of the mock function as +`argK`. For example: + +```cpp +ACTION(IncrementArg1) { return ++(*arg1); } +``` + +allows you to write + +```cpp +... WillOnce(IncrementArg1()); +``` + +Note that you don't need to specify the types of the mock function arguments. +Rest assured that your code is type-safe though: you'll get a compiler error if +`*arg1` doesn't support the `++` operator, or if the type of `++(*arg1)` isn't +compatible with the mock function's return type. + +Another example: + +```cpp +ACTION(Foo) { + (*arg2)(5); + Blah(); + *arg1 = 0; + return arg0; +} +``` + +defines an action `Foo()` that invokes argument #2 (a function pointer) with 5, +calls function `Blah()`, sets the value pointed to by argument #1 to 0, and +returns argument #0. 
+ +For more convenience and flexibility, you can also use the following pre-defined +symbols in the body of `ACTION`: + +`argK_type` | The type of the K-th (0-based) argument of the mock function +:-------------- | :----------------------------------------------------------- +`args` | All arguments of the mock function as a tuple +`args_type` | The type of all arguments of the mock function as a tuple +`return_type` | The return type of the mock function +`function_type` | The type of the mock function + +For example, when using an `ACTION` as a stub action for mock function: + +```cpp +int DoSomething(bool flag, int* ptr); +``` + +we have: + +Pre-defined Symbol | Is Bound To +------------------ | --------------------------------- +`arg0` | the value of `flag` +`arg0_type` | the type `bool` +`arg1` | the value of `ptr` +`arg1_type` | the type `int*` +`args` | the tuple `(flag, ptr)` +`args_type` | the type `std::tuple` +`return_type` | the type `int` +`function_type` | the type `int(bool, int*)` + +#### Legacy macro-based parameterized Actions + +Sometimes you'll want to parameterize an action you define. For that we have +another macro + +```cpp +ACTION_P(name, param) { statements; } +``` + +For example, + +```cpp +ACTION_P(Add, n) { return arg0 + n; } +``` + +will allow you to write + +```cpp +// Returns argument #0 + 5. +... WillOnce(Add(5)); +``` + +For convenience, we use the term *arguments* for the values used to invoke the +mock function, and the term *parameters* for the values used to instantiate an +action. + +Note that you don't need to provide the type of the parameter either. Suppose +the parameter is named `param`, you can also use the gMock-defined symbol +`param_type` to refer to the type of the parameter as inferred by the compiler. +For example, in the body of `ACTION_P(Add, n)` above, you can write `n_type` for +the type of `n`. + +gMock also provides `ACTION_P2`, `ACTION_P3`, and etc to support multi-parameter +actions. 
For example, + +```cpp +ACTION_P2(ReturnDistanceTo, x, y) { + double dx = arg0 - x; + double dy = arg1 - y; + return sqrt(dx*dx + dy*dy); +} +``` + +lets you write + +```cpp +... WillOnce(ReturnDistanceTo(5.0, 26.5)); +``` + +You can view `ACTION` as a degenerated parameterized action where the number of +parameters is 0. + +You can also easily define actions overloaded on the number of parameters: + +```cpp +ACTION_P(Plus, a) { ... } +ACTION_P2(Plus, a, b) { ... } +``` + +### Restricting the Type of an Argument or Parameter in an ACTION + +For maximum brevity and reusability, the `ACTION*` macros don't ask you to +provide the types of the mock function arguments and the action parameters. +Instead, we let the compiler infer the types for us. + +Sometimes, however, we may want to be more explicit about the types. There are +several tricks to do that. For example: + +```cpp +ACTION(Foo) { + // Makes sure arg0 can be converted to int. + int n = arg0; + ... use n instead of arg0 here ... +} + +ACTION_P(Bar, param) { + // Makes sure the type of arg1 is const char*. + ::testing::StaticAssertTypeEq<const char*, arg1_type>(); + + // Makes sure param can be converted to bool. + bool flag = param; +} +``` + +where `StaticAssertTypeEq` is a compile-time assertion in googletest that +verifies two types are the same. + +### Writing New Action Templates Quickly + +Sometimes you want to give an action explicit template parameters that cannot be +inferred from its value parameters. `ACTION_TEMPLATE()` supports that and can be +viewed as an extension to `ACTION()` and `ACTION_P*()`. + +The syntax: + +```cpp +ACTION_TEMPLATE(ActionName, + HAS_m_TEMPLATE_PARAMS(kind1, name1, ..., kind_m, name_m), + AND_n_VALUE_PARAMS(p1, ..., p_n)) { statements; } +``` + +defines an action template that takes *m* explicit template parameters and *n* +value parameters, where *m* is in [1, 10] and *n* is in [0, 10]. 
`name_i` is the +name of the *i*-th template parameter, and `kind_i` specifies whether it's a +`typename`, an integral constant, or a template. `p_i` is the name of the *i*-th +value parameter. + +Example: + +```cpp +// DuplicateArg<k, T>(output) converts the k-th argument of the mock +// function to type T and copies it to *output. +ACTION_TEMPLATE(DuplicateArg, + // Note the comma between int and k: + HAS_2_TEMPLATE_PARAMS(int, k, typename, T), + AND_1_VALUE_PARAMS(output)) { + *output = T(std::get<k>(args)); +} +``` + +To create an instance of an action template, write: + +```cpp +ActionName<t1, ..., t_m>(v1, ..., v_n) +``` + +where the `t`s are the template arguments and the `v`s are the value arguments. +The value argument types are inferred by the compiler. For example: + +```cpp +using ::testing::_; +... + int n; + EXPECT_CALL(mock, Foo).WillOnce(DuplicateArg<1, unsigned char>(&n)); +``` + +If you want to explicitly specify the value argument types, you can provide +additional template arguments: + +```cpp +ActionName<t1, ..., t_m, u1, ..., u_k>(v1, ..., v_n) +``` + +where `u_i` is the desired type of `v_i`. + +`ACTION_TEMPLATE` and `ACTION`/`ACTION_P*` can be overloaded on the number of +value parameters, but not on the number of template parameters. Without the +restriction, the meaning of the following is unclear: + +```cpp + OverloadedAction<bool>(x); +``` + +Are we using a single-template-parameter action where `bool` refers to the type +of `x`, or a two-template-parameter action where the compiler is asked to infer +the type of `x`? + +### Using the ACTION Object's Type + +If you are writing a function that returns an `ACTION` object, you'll need to +know its type. The type depends on the macro used to define the action and the +parameter types. 
The rule is relatively simple: + + +| Given Definition | Expression | Has Type | +| ----------------------------- | ------------------- | --------------------- | +| `ACTION(Foo)` | `Foo()` | `FooAction` | +| `ACTION_TEMPLATE(Foo, HAS_m_TEMPLATE_PARAMS(...), AND_0_VALUE_PARAMS())` | `Foo<t1, ..., t_m>()` | `FooAction<t1, ..., t_m>` | +| `ACTION_P(Bar, param)` | `Bar(int_value)` | `BarActionP<int>` | +| `ACTION_TEMPLATE(Bar, HAS_m_TEMPLATE_PARAMS(...), AND_1_VALUE_PARAMS(p1))` | `Bar<t1, ..., t_m>(int_value)` | `BarActionP<t1, ..., t_m, int>` | +| `ACTION_P2(Baz, p1, p2)` | `Baz(bool_value, int_value)` | `BazActionP2<bool, int>` | +| `ACTION_TEMPLATE(Baz, HAS_m_TEMPLATE_PARAMS(...), AND_2_VALUE_PARAMS(p1, p2))` | `Baz<t1, ..., t_m>(bool_value, int_value)` | `BazActionP2<t1, ..., t_m, bool, int>` | +| ... | ... | ... | + + +Note that we have to pick different suffixes (`Action`, `ActionP`, `ActionP2`, +and etc) for actions with different numbers of value parameters, or the action +definitions cannot be overloaded on the number of them. + +### Writing New Monomorphic Actions {#NewMonoActions} + +While the `ACTION*` macros are very convenient, sometimes they are +inappropriate. For example, despite the tricks shown in the previous recipes, +they don't let you directly specify the types of the mock function arguments and +the action parameters, which in general leads to unoptimized compiler error +messages that can baffle unfamiliar users. They also don't allow overloading +actions based on parameter types without jumping through some hoops. + +An alternative to the `ACTION*` macros is to implement +`::testing::ActionInterface<F>`, where `F` is the type of the mock function in +which the action will be used. For example: + +```cpp +template <typename F> +class ActionInterface { + public: + virtual ~ActionInterface(); + + // Performs the action. Result is the return type of function type + // F, and ArgumentTuple is the tuple of arguments of F. + // + + // For example, if F is int(bool, const string&), then Result would + // be int, and ArgumentTuple would be std::tuple<bool, const string&>. 
+ + virtual Result Perform(const ArgumentTuple& args) = 0; +}; +``` + +```cpp +using ::testing::_; +using ::testing::Action; +using ::testing::ActionInterface; +using ::testing::MakeAction; + +typedef int IncrementMethod(int*); + +class IncrementArgumentAction : public ActionInterface<IncrementMethod> { + public: + int Perform(const std::tuple<int*>& args) override { + int* p = std::get<0>(args); // Grabs the first argument. + return (*p)++; + } +}; + +Action<IncrementMethod> IncrementArgument() { + return MakeAction(new IncrementArgumentAction); +} + +... + EXPECT_CALL(foo, Baz(_)) + .WillOnce(IncrementArgument()); + + int n = 5; + foo.Baz(&n); // Should return 5 and change n to 6. +``` + +### Writing New Polymorphic Actions {#NewPolyActions} + +The previous recipe showed you how to define your own action. This is all good, +except that you need to know the type of the function in which the action will +be used. Sometimes that can be a problem. For example, if you want to use the +action in functions with *different* types (e.g. like `Return()` and +`SetArgPointee()`). + +If an action can be used in several types of mock functions, we say it's +*polymorphic*. The `MakePolymorphicAction()` function template makes it easy to +define such an action: + +```cpp +namespace testing { +template <typename Impl> +PolymorphicAction<Impl> MakePolymorphicAction(const Impl& impl); +} // namespace testing +``` + +As an example, let's define an action that returns the second argument in the +mock function's argument list. The first step is to define an implementation +class: + +```cpp +class ReturnSecondArgumentAction { + public: + template <typename Result, typename ArgumentTuple> + Result Perform(const ArgumentTuple& args) const { + // To get the i-th (0-based) argument, use std::get<i>(args). + return std::get<1>(args); + } +}; +``` + +This implementation class does *not* need to inherit from any particular class. +What matters is that it must have a `Perform()` method template. 
This method +template takes the mock function's arguments as a tuple in a **single** +argument, and returns the result of the action. It can be either `const` or not, +but must be invocable with exactly one template argument, which is the result +type. In other words, you must be able to call `Perform<R>(args)` where `R` is +the mock function's return type and `args` is its arguments in a tuple. + +Next, we use `MakePolymorphicAction()` to turn an instance of the implementation +class into the polymorphic action we need. It will be convenient to have a +wrapper for this: + +```cpp +using ::testing::MakePolymorphicAction; +using ::testing::PolymorphicAction; + +PolymorphicAction<ReturnSecondArgumentAction> ReturnSecondArgument() { + return MakePolymorphicAction(ReturnSecondArgumentAction()); +} +``` + +Now, you can use this polymorphic action the same way you use the built-in ones: + +```cpp +using ::testing::_; + +class MockFoo : public Foo { + public: + MOCK_METHOD(int, DoThis, (bool flag, int n), (override)); + MOCK_METHOD(string, DoThat, (int x, const char* str1, const char* str2), + (override)); +}; + + ... + MockFoo foo; + EXPECT_CALL(foo, DoThis).WillOnce(ReturnSecondArgument()); + EXPECT_CALL(foo, DoThat).WillOnce(ReturnSecondArgument()); + ... + foo.DoThis(true, 5); // Will return 5. + foo.DoThat(1, "Hi", "Bye"); // Will return "Hi". +``` + +### Teaching gMock How to Print Your Values + +When an uninteresting or unexpected call occurs, gMock prints the argument +values and the stack trace to help you debug. Assertion macros like +`EXPECT_THAT` and `EXPECT_EQ` also print the values in question when the +assertion fails. gMock and googletest do this using googletest's user-extensible +value printer. + +This printer knows how to print built-in C++ types, native arrays, STL +containers, and any type that supports the `<<` operator. For other types, it +prints the raw bytes in the value and hopes that you the user can figure it out. 
+[The GoogleTest advanced guide](advanced.md#teaching-googletest-how-to-print-your-values) +explains how to extend the printer to do a better job at printing your +particular type than to dump the bytes. + +## Useful Mocks Created Using gMock + + + + +### Mock std::function {#MockFunction} + +`std::function` is a general function type introduced in C++11. It is a +preferred way of passing callbacks to new interfaces. Functions are copyable, +and are not usually passed around by pointer, which makes them tricky to mock. +But fear not - `MockFunction` can help you with that. + +`MockFunction<R(T1, ..., Tn)>` has a mock method `Call()` with the signature: + +```cpp + R Call(T1, ..., Tn); +``` + +It also has a `AsStdFunction()` method, which creates a `std::function` proxy +forwarding to Call: + +```cpp + std::function<R(T1, ..., Tn)> AsStdFunction(); +``` + +To use `MockFunction`, first create `MockFunction` object and set up +expectations on its `Call` method. Then pass proxy obtained from +`AsStdFunction()` to the code you are testing. For example: + +```cpp +TEST(FooTest, RunsCallbackWithBarArgument) { + // 1. Create a mock object. + MockFunction<int(std::string)> mock_function; + + // 2. Set expectations on Call() method. + EXPECT_CALL(mock_function, Call("bar")).WillOnce(Return(1)); + + // 3. Exercise code that uses std::function. + Foo(mock_function.AsStdFunction()); + // Foo's signature can be either of: + // void Foo(const std::function<int(std::string)>& fun); + // void Foo(std::function<int(std::string)> fun); + + // 4. All expectations will be verified when mock_function + // goes out of scope and is destroyed. +} +``` + +Remember that function objects created with `AsStdFunction()` are just +forwarders. If you create multiple of them, they will share the same set of +expectations. + +Although `std::function` supports unlimited number of arguments, `MockFunction` +implementation is limited to ten. If you ever hit that limit... well, your +callback has bigger problems than being mockable. 
:-) diff --git a/third_party/googletest/docs/gmock_faq.md b/third_party/googletest/docs/gmock_faq.md new file mode 100644 index 0000000..8f220bf --- /dev/null +++ b/third_party/googletest/docs/gmock_faq.md @@ -0,0 +1,390 @@ +# Legacy gMock FAQ + +### When I call a method on my mock object, the method for the real object is invoked instead. What's the problem? + +In order for a method to be mocked, it must be *virtual*, unless you use the +[high-perf dependency injection technique](gmock_cook_book.md#MockingNonVirtualMethods). + +### Can I mock a variadic function? + +You cannot mock a variadic function (i.e. a function taking ellipsis (`...`) +arguments) directly in gMock. + +The problem is that in general, there is *no way* for a mock object to know how +many arguments are passed to the variadic method, and what the arguments' types +are. Only the *author of the base class* knows the protocol, and we cannot look +into his or her head. + +Therefore, to mock such a function, the *user* must teach the mock object how to +figure out the number of arguments and their types. One way to do it is to +provide overloaded versions of the function. + +Ellipsis arguments are inherited from C and not really a C++ feature. They are +unsafe to use and don't work with arguments that have constructors or +destructors. Therefore we recommend to avoid them in C++ as much as possible. + +### MSVC gives me warning C4301 or C4373 when I define a mock method with a const parameter. Why? + +If you compile this using Microsoft Visual C++ 2005 SP1: + +```cpp +class Foo { + ... + virtual void Bar(const int i) = 0; +}; + +class MockFoo : public Foo { + ... + MOCK_METHOD(void, Bar, (const int i), (override)); +}; +``` + +You may get the following warning: + +```shell +warning C4301: 'MockFoo::Bar': overriding virtual function only differs from 'Foo::Bar' by const/volatile qualifier +``` + +This is a MSVC bug. The same code compiles fine with gcc, for example. 
If you +use Visual C++ 2008 SP1, you would get the warning: + +```shell +warning C4373: 'MockFoo::Bar': virtual function overrides 'Foo::Bar', previous versions of the compiler did not override when parameters only differed by const/volatile qualifiers +``` + +In C++, if you *declare* a function with a `const` parameter, the `const` +modifier is ignored. Therefore, the `Foo` base class above is equivalent to: + +```cpp +class Foo { + ... + virtual void Bar(int i) = 0; // int or const int? Makes no difference. +}; +``` + +In fact, you can *declare* `Bar()` with an `int` parameter, and define it with a +`const int` parameter. The compiler will still match them up. + +Since making a parameter `const` is meaningless in the method declaration, we +recommend to remove it in both `Foo` and `MockFoo`. That should workaround the +VC bug. + +Note that we are talking about the *top-level* `const` modifier here. If the +function parameter is passed by pointer or reference, declaring the pointee or +referee as `const` is still meaningful. For example, the following two +declarations are *not* equivalent: + +```cpp +void Bar(int* p); // Neither p nor *p is const. +void Bar(const int* p); // p is not const, but *p is. +``` + +### I can't figure out why gMock thinks my expectations are not satisfied. What should I do? + +You might want to run your test with `--gmock_verbose=info`. This flag lets +gMock print a trace of every mock function call it receives. By studying the +trace, you'll gain insights on why the expectations you set are not met. + +If you see the message "The mock function has no default action set, and its +return type has no default value set.", then try +[adding a default action](gmock_cheat_sheet.md#OnCall). Due to a known issue, +unexpected calls on mocks without default actions don't print out a detailed +comparison between the actual arguments and the expected arguments. + +### My program crashed and `ScopedMockLog` spit out tons of messages. 
Is it a gMock bug? + +gMock and `ScopedMockLog` are likely doing the right thing here. + +When a test crashes, the failure signal handler will try to log a lot of +information (the stack trace, and the address map, for example). The messages +are compounded if you have many threads with depth stacks. When `ScopedMockLog` +intercepts these messages and finds that they don't match any expectations, it +prints an error for each of them. + +You can learn to ignore the errors, or you can rewrite your expectations to make +your test more robust, for example, by adding something like: + +```cpp +using ::testing::AnyNumber; +using ::testing::Not; +... + // Ignores any log not done by us. + EXPECT_CALL(log, Log(_, Not(EndsWith("/my_file.cc")), _)) + .Times(AnyNumber()); +``` + +### How can I assert that a function is NEVER called? + +```cpp +using ::testing::_; +... + EXPECT_CALL(foo, Bar(_)) + .Times(0); +``` + +### I have a failed test where gMock tells me TWICE that a particular expectation is not satisfied. Isn't this redundant? + +When gMock detects a failure, it prints relevant information (the mock function +arguments, the state of relevant expectations, and etc) to help the user debug. +If another failure is detected, gMock will do the same, including printing the +state of relevant expectations. + +Sometimes an expectation's state didn't change between two failures, and you'll +see the same description of the state twice. They are however *not* redundant, +as they refer to *different points in time*. The fact they are the same *is* +interesting information. + +### I get a heapcheck failure when using a mock object, but using a real object is fine. What can be wrong? + +Does the class (hopefully a pure interface) you are mocking have a virtual +destructor? + +Whenever you derive from a base class, make sure its destructor is virtual. +Otherwise Bad Things will happen. Consider the following code: + +```cpp +class Base { + public: + // Not virtual, but should be. 
+ ~Base() { ... } + ... +}; + +class Derived : public Base { + public: + ... + private: + std::string value_; +}; + +... + Base* p = new Derived; + ... + delete p; // Surprise! ~Base() will be called, but ~Derived() will not + // - value_ is leaked. +``` + +By changing `~Base()` to virtual, `~Derived()` will be correctly called when +`delete p` is executed, and the heap checker will be happy. + +### The "newer expectations override older ones" rule makes writing expectations awkward. Why does gMock do that? + +When people complain about this, often they are referring to code like: + +```cpp +using ::testing::Return; +... + // foo.Bar() should be called twice, return 1 the first time, and return + // 2 the second time. However, I have to write the expectations in the + // reverse order. This sucks big time!!! + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(2)) + .RetiresOnSaturation(); + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(1)) + .RetiresOnSaturation(); +``` + +The problem, is that they didn't pick the **best** way to express the test's +intent. + +By default, expectations don't have to be matched in *any* particular order. If +you want them to match in a certain order, you need to be explicit. This is +gMock's (and jMock's) fundamental philosophy: it's easy to accidentally +over-specify your tests, and we want to make it harder to do so. + +There are two better ways to write the test spec. You could either put the +expectations in sequence: + +```cpp +using ::testing::Return; +... + // foo.Bar() should be called twice, return 1 the first time, and return + // 2 the second time. Using a sequence, we can write the expectations + // in their natural order. + { + InSequence s; + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(1)) + .RetiresOnSaturation(); + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(2)) + .RetiresOnSaturation(); + } +``` + +or you can put the sequence of actions in the same expectation: + +```cpp +using ::testing::Return; +... 
+ // foo.Bar() should be called twice, return 1 the first time, and return + // 2 the second time. + EXPECT_CALL(foo, Bar()) + .WillOnce(Return(1)) + .WillOnce(Return(2)) + .RetiresOnSaturation(); +``` + +Back to the original questions: why does gMock search the expectations (and +`ON_CALL`s) from back to front? Because this allows a user to set up a mock's +behavior for the common case early (e.g. in the mock's constructor or the test +fixture's set-up phase) and customize it with more specific rules later. If +gMock searches from front to back, this very useful pattern won't be possible. + +### gMock prints a warning when a function without EXPECT_CALL is called, even if I have set its behavior using ON_CALL. Would it be reasonable not to show the warning in this case? + +When choosing between being neat and being safe, we lean toward the latter. So +the answer is that we think it's better to show the warning. + +Often people write `ON_CALL`s in the mock object's constructor or `SetUp()`, as +the default behavior rarely changes from test to test. Then in the test body +they set the expectations, which are often different for each test. Having an +`ON_CALL` in the set-up part of a test doesn't mean that the calls are expected. +If there's no `EXPECT_CALL` and the method is called, it's possibly an error. If +we quietly let the call go through without notifying the user, bugs may creep in +unnoticed. + +If, however, you are sure that the calls are OK, you can write + +```cpp +using ::testing::_; +... + EXPECT_CALL(foo, Bar(_)) + .WillRepeatedly(...); +``` + +instead of + +```cpp +using ::testing::_; +... + ON_CALL(foo, Bar(_)) + .WillByDefault(...); +``` + +This tells gMock that you do expect the calls and no warning should be printed. + +Also, you can control the verbosity by specifying `--gmock_verbose=error`. Other +values are `info` and `warning`. If you find the output too noisy when +debugging, just choose a less verbose level. 
+ +### How can I delete the mock function's argument in an action? + +If your mock function takes a pointer argument and you want to delete that +argument, you can use testing::DeleteArg<N>() to delete the N'th (zero-indexed) +argument: + +```cpp +using ::testing::_; + ... + MOCK_METHOD(void, Bar, (X* x, const Y& y)); + ... + EXPECT_CALL(mock_foo_, Bar(_, _)) + .WillOnce(testing::DeleteArg<0>()); +``` + +### How can I perform an arbitrary action on a mock function's argument? + +If you find yourself needing to perform some action that's not supported by +gMock directly, remember that you can define your own actions using +[`MakeAction()`](#NewMonoActions) or +[`MakePolymorphicAction()`](#NewPolyActions), or you can write a stub function +and invoke it using [`Invoke()`](#FunctionsAsActions). + +```cpp +using ::testing::_; +using ::testing::Invoke; + ... + MOCK_METHOD(void, Bar, (X* p)); + ... + EXPECT_CALL(mock_foo_, Bar(_)) + .WillOnce(Invoke(MyAction(...))); +``` + +### My code calls a static/global function. Can I mock it? + +You can, but you need to make some changes. + +In general, if you find yourself needing to mock a static function, it's a sign +that your modules are too tightly coupled (and less flexible, less reusable, +less testable, etc). You are probably better off defining a small interface and +call the function through that interface, which then can be easily mocked. It's +a bit of work initially, but usually pays for itself quickly. + +This Google Testing Blog +[post](https://testing.googleblog.com/2008/06/defeat-static-cling.html) says it +excellently. Check it out. + +### My mock object needs to do complex stuff. It's a lot of pain to specify the actions. gMock sucks! + +I know it's not a question, but you get an answer for free any way. :-) + +With gMock, you can create mocks in C++ easily. And people might be tempted to +use them everywhere. Sometimes they work great, and sometimes you may find them, +well, a pain to use. 
So, what's wrong in the latter case? + +When you write a test without using mocks, you exercise the code and assert that +it returns the correct value or that the system is in an expected state. This is +sometimes called "state-based testing". + +Mocks are great for what some call "interaction-based" testing: instead of +checking the system state at the very end, mock objects verify that they are +invoked the right way and report an error as soon as it arises, giving you a +handle on the precise context in which the error was triggered. This is often +more effective and economical to do than state-based testing. + +If you are doing state-based testing and using a test double just to simulate +the real object, you are probably better off using a fake. Using a mock in this +case causes pain, as it's not a strong point for mocks to perform complex +actions. If you experience this and think that mocks suck, you are just not +using the right tool for your problem. Or, you might be trying to solve the +wrong problem. :-) + +### I got a warning "Uninteresting function call encountered - default action taken.." Should I panic? + +By all means, NO! It's just an FYI. :-) + +What it means is that you have a mock function, you haven't set any expectations +on it (by gMock's rule this means that you are not interested in calls to this +function and therefore it can be called any number of times), and it is called. +That's OK - you didn't say it's not OK to call the function! + +What if you actually meant to disallow this function to be called, but forgot to +write `EXPECT_CALL(foo, Bar()).Times(0)`? While one can argue that it's the +user's fault, gMock tries to be nice and prints you a note. + +So, when you see the message and believe that there shouldn't be any +uninteresting calls, you should investigate what's going on. To make your life +easier, gMock dumps the stack trace when an uninteresting call is encountered. 
+From that you can figure out which mock function it is, and how it is called. + +### I want to define a custom action. Should I use Invoke() or implement the ActionInterface interface? + +Either way is fine - you want to choose the one that's more convenient for your +circumstance. + +Usually, if your action is for a particular function type, defining it using +`Invoke()` should be easier; if your action can be used in functions of +different types (e.g. if you are defining `Return(*value*)`), +`MakePolymorphicAction()` is easiest. Sometimes you want precise control on what +types of functions the action can be used in, and implementing `ActionInterface` +is the way to go here. See the implementation of `Return()` in `gmock-actions.h` +for an example. + +### I use SetArgPointee() in WillOnce(), but gcc complains about "conflicting return type specified". What does it mean? + +You got this error as gMock has no idea what value it should return when the +mock method is called. `SetArgPointee()` says what the side effect is, but +doesn't say what the return value should be. You need `DoAll()` to chain a +`SetArgPointee()` with a `Return()` that provides a value appropriate to the API +being mocked. + +See this [recipe](gmock_cook_book.md#mocking-side-effects) for more details and +an example. + +### I have a huge mock class, and Microsoft Visual C++ runs out of memory when compiling it. What can I do? + +We've noticed that when the `/clr` compiler flag is used, Visual C++ uses 5~6 +times as much memory when compiling a mock class. We suggest to avoid `/clr` +when compiling native C++ mocks. diff --git a/third_party/googletest/docs/gmock_for_dummies.md b/third_party/googletest/docs/gmock_for_dummies.md new file mode 100644 index 0000000..43f907a --- /dev/null +++ b/third_party/googletest/docs/gmock_for_dummies.md @@ -0,0 +1,700 @@ +# gMock for Dummies + +## What Is gMock? 
+ +When you write a prototype or test, often it's not feasible or wise to rely on +real objects entirely. A **mock object** implements the same interface as a real +object (so it can be used as one), but lets you specify at run time how it will +be used and what it should do (which methods will be called? in which order? how +many times? with what arguments? what will they return? etc). + +It is easy to confuse the term *fake objects* with mock objects. Fakes and mocks +actually mean very different things in the Test-Driven Development (TDD) +community: + +* **Fake** objects have working implementations, but usually take some + shortcut (perhaps to make the operations less expensive), which makes them + not suitable for production. An in-memory file system would be an example of + a fake. +* **Mocks** are objects pre-programmed with *expectations*, which form a + specification of the calls they are expected to receive. + +If all this seems too abstract for you, don't worry - the most important thing +to remember is that a mock allows you to check the *interaction* between itself +and code that uses it. The difference between fakes and mocks shall become much +clearer once you start to use mocks. + +**gMock** is a library (sometimes we also call it a "framework" to make it sound +cool) for creating mock classes and using them. It does to C++ what +jMock/EasyMock does to Java (well, more or less). + +When using gMock, + +1. first, you use some simple macros to describe the interface you want to + mock, and they will expand to the implementation of your mock class; +2. next, you create some mock objects and specify its expectations and behavior + using an intuitive syntax; +3. then you exercise code that uses the mock objects. gMock will catch any + violation to the expectations as soon as it arises. + +## Why gMock? 
+ +While mock objects help you remove unnecessary dependencies in tests and make +them fast and reliable, using mocks manually in C++ is *hard*: + +* Someone has to implement the mocks. The job is usually tedious and + error-prone. No wonder people go great distance to avoid it. +* The quality of those manually written mocks is a bit, uh, unpredictable. You + may see some really polished ones, but you may also see some that were + hacked up in a hurry and have all sorts of ad hoc restrictions. +* The knowledge you gained from using one mock doesn't transfer to the next + one. + +In contrast, Java and Python programmers have some fine mock frameworks (jMock, +EasyMock, etc), which automate the creation of mocks. As a result, mocking is a +proven effective technique and widely adopted practice in those communities. +Having the right tool absolutely makes the difference. + +gMock was built to help C++ programmers. It was inspired by jMock and EasyMock, +but designed with C++'s specifics in mind. It is your friend if any of the +following problems is bothering you: + +* You are stuck with a sub-optimal design and wish you had done more + prototyping before it was too late, but prototyping in C++ is by no means + "rapid". +* Your tests are slow as they depend on too many libraries or use expensive + resources (e.g. a database). +* Your tests are brittle as some resources they use are unreliable (e.g. the + network). +* You want to test how your code handles a failure (e.g. a file checksum + error), but it's not easy to cause one. +* You need to make sure that your module interacts with other modules in the + right way, but it's hard to observe the interaction; therefore you resort to + observing the side effects at the end of the action, but it's awkward at + best. +* You want to "mock out" your dependencies, except that they don't have mock + implementations yet; and, frankly, you aren't thrilled by some of those + hand-written mocks. 
+ +We encourage you to use gMock as + +* a *design* tool, for it lets you experiment with your interface design early + and often. More iterations lead to better designs! +* a *testing* tool to cut your tests' outbound dependencies and probe the + interaction between your module and its collaborators. + +## Getting Started + +gMock is bundled with googletest. + +## A Case for Mock Turtles + +Let's look at an example. Suppose you are developing a graphics program that +relies on a [LOGO](http://en.wikipedia.org/wiki/Logo_programming_language)-like +API for drawing. How would you test that it does the right thing? Well, you can +run it and compare the screen with a golden screen snapshot, but let's admit it: +tests like this are expensive to run and fragile (What if you just upgraded to a +shiny new graphics card that has better anti-aliasing? Suddenly you have to +update all your golden images.). It would be too painful if all your tests are +like this. Fortunately, you learned about +[Dependency Injection](http://en.wikipedia.org/wiki/Dependency_injection) and know the right thing +to do: instead of having your application talk to the system API directly, wrap +the API in an interface (say, `Turtle`) and code to that interface: + +```cpp +class Turtle { + ... + virtual ~Turtle() {} + virtual void PenUp() = 0; + virtual void PenDown() = 0; + virtual void Forward(int distance) = 0; + virtual void Turn(int degrees) = 0; + virtual void GoTo(int x, int y) = 0; + virtual int GetX() const = 0; + virtual int GetY() const = 0; +}; +``` + +(Note that the destructor of `Turtle` **must** be virtual, as is the case for +**all** classes you intend to inherit from - otherwise the destructor of the +derived class will not be called when you delete an object through a base +pointer, and you'll get corrupted program states like memory leaks.) 
+ +You can control whether the turtle's movement will leave a trace using `PenUp()` +and `PenDown()`, and control its movement using `Forward()`, `Turn()`, and +`GoTo()`. Finally, `GetX()` and `GetY()` tell you the current position of the +turtle. + +Your program will normally use a real implementation of this interface. In +tests, you can use a mock implementation instead. This allows you to easily +check what drawing primitives your program is calling, with what arguments, and +in which order. Tests written this way are much more robust (they won't break +because your new machine does anti-aliasing differently), easier to read and +maintain (the intent of a test is expressed in the code, not in some binary +images), and run *much, much faster*. + +## Writing the Mock Class + +If you are lucky, the mocks you need to use have already been implemented by +some nice people. If, however, you find yourself in the position to write a mock +class, relax - gMock turns this task into a fun game! (Well, almost.) + +### How to Define It + +Using the `Turtle` interface as example, here are the simple steps you need to +follow: + +* Derive a class `MockTurtle` from `Turtle`. +* Take a *virtual* function of `Turtle` (while it's possible to + [mock non-virtual methods using templates](gmock_cook_book.md#MockingNonVirtualMethods), + it's much more involved). +* In the `public:` section of the child class, write `MOCK_METHOD();` +* Now comes the fun part: you take the function signature, cut-and-paste it + into the macro, and add two commas - one between the return type and the + name, another between the name and the argument list. +* If you're mocking a const method, add a 4th parameter containing `(const)` + (the parentheses are required). +* Since you're overriding a virtual method, we suggest adding the `override` + keyword. For const methods the 4th parameter becomes `(const, override)`, + for non-const methods just `(override)`. This isn't mandatory. 
+* Repeat until all virtual functions you want to mock are done. (It goes + without saying that *all* pure virtual methods in your abstract class must + be either mocked or overridden.) + +After the process, you should have something like: + +```cpp +#include <gmock/gmock.h>  // Brings in gMock. + +class MockTurtle : public Turtle { + public: + ... + MOCK_METHOD(void, PenUp, (), (override)); + MOCK_METHOD(void, PenDown, (), (override)); + MOCK_METHOD(void, Forward, (int distance), (override)); + MOCK_METHOD(void, Turn, (int degrees), (override)); + MOCK_METHOD(void, GoTo, (int x, int y), (override)); + MOCK_METHOD(int, GetX, (), (const, override)); + MOCK_METHOD(int, GetY, (), (const, override)); +}; +``` + +You don't need to define these mock methods somewhere else - the `MOCK_METHOD` +macro will generate the definitions for you. It's that simple! + +### Where to Put It + +When you define a mock class, you need to decide where to put its definition. +Some people put it in a `_test.cc`. This is fine when the interface being mocked +(say, `Foo`) is owned by the same person or team. Otherwise, when the owner of +`Foo` changes it, your test could break. (You can't really expect `Foo`'s +maintainer to fix every test that uses `Foo`, can you?) + +Generally, you should not mock classes you don't own. If you must mock such a +class owned by others, define the mock class in `Foo`'s Bazel package (usually +the same directory or a `testing` sub-directory), and put it in a `.h` and a +`cc_library` with `testonly=True`. Then everyone can reference them from their +tests. If `Foo` ever changes, there is only one copy of `MockFoo` to change, and +only tests that depend on the changed methods need to be fixed. + +Another way to do it: you can introduce a thin layer `FooAdaptor` on top of +`Foo` and code to this new interface. Since you own `FooAdaptor`, you can absorb +changes in `Foo` much more easily.
While this is more work initially, carefully +choosing the adaptor interface can make your code easier to write and more +readable (a net win in the long run), as you can choose `FooAdaptor` to fit your +specific domain much better than `Foo` does. + +## Using Mocks in Tests + +Once you have a mock class, using it is easy. The typical work flow is: + +1. Import the gMock names from the `testing` namespace such that you can use + them unqualified (You only have to do it once per file). Remember that + namespaces are a good idea. +2. Create some mock objects. +3. Specify your expectations on them (How many times will a method be called? + With what arguments? What should it do? etc.). +4. Exercise some code that uses the mocks; optionally, check the result using + googletest assertions. If a mock method is called more than expected or with + wrong arguments, you'll get an error immediately. +5. When a mock is destructed, gMock will automatically check whether all + expectations on it have been satisfied. + +Here's an example: + +```cpp +#include "path/to/mock-turtle.h" +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +using ::testing::AtLeast; // #1 + +TEST(PainterTest, CanDrawSomething) { + MockTurtle turtle; // #2 + EXPECT_CALL(turtle, PenDown()) // #3 + .Times(AtLeast(1)); + + Painter painter(&turtle); // #4 + + EXPECT_TRUE(painter.DrawCircle(0, 0, 10)); // #5 +} +``` + +As you might have guessed, this test checks that `PenDown()` is called at least +once. If the `painter` object didn't call this method, your test will fail with +a message like this: + +```text +path/to/my_test.cc:119: Failure +Actual function call count doesn't match this expectation: +Actually: never called; +Expected: called at least once. +Stack trace: +... +``` + +**Tip 1:** If you run the test from an Emacs buffer, you can hit `<Enter>` on +the line number to jump right to the failed expectation. + +**Tip 2:** If your mock objects are never deleted, the final verification won't +happen.
Therefore it's a good idea to turn on the heap checker in your tests +when you allocate mocks on the heap. You get that automatically if you use the +`gtest_main` library already. + +**Important note:** gMock requires expectations to be set **before** the mock +functions are called, otherwise the behavior is **undefined**. Do not alternate +between calls to `EXPECT_CALL()` and calls to the mock functions, and do not set +any expectations on a mock after passing the mock to an API. + +This means `EXPECT_CALL()` should be read as expecting that a call will occur +*in the future*, not that a call has occurred. Why does gMock work like that? +Well, specifying the expectation beforehand allows gMock to report a violation +as soon as it arises, when the context (stack trace, etc) is still available. +This makes debugging much easier. + +Admittedly, this test is contrived and doesn't do much. You can easily achieve +the same effect without using gMock. However, as we shall reveal soon, gMock +allows you to do *so much more* with the mocks. + +## Setting Expectations + +The key to using a mock object successfully is to set the *right expectations* +on it. If you set the expectations too strict, your test will fail as the result +of unrelated changes. If you set them too loose, bugs can slip through. You want +to do it just right such that your test can catch exactly the kind of bugs you +intend it to catch. gMock provides the necessary means for you to do it "just +right." + +### General Syntax + +In gMock we use the `EXPECT_CALL()` macro to set an expectation on a mock +method. The general syntax is: + +```cpp +EXPECT_CALL(mock_object, method(matchers)) + .Times(cardinality) + .WillOnce(action) + .WillRepeatedly(action); +``` + +The macro has two arguments: first the mock object, and then the method and its +arguments. Note that the two are separated by a comma (`,`), not a period (`.`). +(Why using a comma? The answer is that it was necessary for technical reasons.)
+If the method is not overloaded, the macro can also be called without matchers: + +```cpp +EXPECT_CALL(mock_object, non-overloaded-method) + .Times(cardinality) + .WillOnce(action) + .WillRepeatedly(action); +``` + +This syntax allows the test writer to specify "called with any arguments" +without explicitly specifying the number or types of arguments. To avoid +unintended ambiguity, this syntax may only be used for methods that are not +overloaded. + +Either form of the macro can be followed by some optional *clauses* that provide +more information about the expectation. We'll discuss how each clause works in +the coming sections. + +This syntax is designed to make an expectation read like English. For example, +you can probably guess that + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetX()) + .Times(5) + .WillOnce(Return(100)) + .WillOnce(Return(150)) + .WillRepeatedly(Return(200)); +``` + +says that the `turtle` object's `GetX()` method will be called five times, it +will return 100 the first time, 150 the second time, and then 200 every time. +Some people like to call this style of syntax a Domain-Specific Language (DSL). + +{: .callout .note} +**Note:** Why do we use a macro to do this? Well it serves two purposes: first +it makes expectations easily identifiable (either by `grep` or by a human +reader), and second it allows gMock to include the source file location of a +failed expectation in messages, making debugging easier. + +### Matchers: What Arguments Do We Expect? + +When a mock function takes arguments, we may specify what arguments we are +expecting, for example: + +```cpp +// Expects the turtle to move forward by 100 units. +EXPECT_CALL(turtle, Forward(100)); +``` + +Oftentimes you do not want to be too specific. Remember that talk about tests +being too rigid? Over specification leads to brittle tests and obscures the +intent of tests. Therefore we encourage you to specify only what's necessary—no +more, no less. 
If you aren't interested in the value of an argument, write `_` +as the argument, which means "anything goes": + +```cpp +using ::testing::_; +... +// Expects that the turtle jumps to somewhere on the x=50 line. +EXPECT_CALL(turtle, GoTo(50, _)); +``` + +`_` is an instance of what we call **matchers**. A matcher is like a predicate +and can test whether an argument is what we'd expect. You can use a matcher +inside `EXPECT_CALL()` wherever a function argument is expected. `_` is a +convenient way of saying "any value". + +In the above examples, `100` and `50` are also matchers; implicitly, they are +the same as `Eq(100)` and `Eq(50)`, which specify that the argument must be +equal (using `operator==`) to the matcher argument. There are many +[built-in matchers](reference/matchers.md) for common types (as well as +[custom matchers](gmock_cook_book.md#NewMatchers)); for example: + +```cpp +using ::testing::Ge; +... +// Expects the turtle moves forward by at least 100. +EXPECT_CALL(turtle, Forward(Ge(100))); +``` + +If you don't care about *any* arguments, rather than specify `_` for each of +them you may instead omit the parameter list: + +```cpp +// Expects the turtle to move forward. +EXPECT_CALL(turtle, Forward); +// Expects the turtle to jump somewhere. +EXPECT_CALL(turtle, GoTo); +``` + +This works for all non-overloaded methods; if a method is overloaded, you need +to help gMock resolve which overload is expected by specifying the number of +arguments and possibly also the +[types of the arguments](gmock_cook_book.md#SelectOverload). + +### Cardinalities: How Many Times Will It Be Called? + +The first clause we can specify following an `EXPECT_CALL()` is `Times()`. We +call its argument a **cardinality** as it tells *how many times* the call should +occur. It allows us to repeat an expectation many times without actually writing +it as many times. More importantly, a cardinality can be "fuzzy", just like a +matcher can be. 
This allows a user to express the intent of a test exactly. + +An interesting special case is when we say `Times(0)`. You may have guessed - it +means that the function shouldn't be called with the given arguments at all, and +gMock will report a googletest failure whenever the function is (wrongfully) +called. + +We've seen `AtLeast(n)` as an example of fuzzy cardinalities earlier. For the +list of built-in cardinalities you can use, see +[here](gmock_cheat_sheet.md#CardinalityList). + +The `Times()` clause can be omitted. **If you omit `Times()`, gMock will infer +the cardinality for you.** The rules are easy to remember: + +* If **neither** `WillOnce()` **nor** `WillRepeatedly()` is in the + `EXPECT_CALL()`, the inferred cardinality is `Times(1)`. +* If there are *n* `WillOnce()`'s but **no** `WillRepeatedly()`, where *n* >= + 1, the cardinality is `Times(n)`. +* If there are *n* `WillOnce()`'s and **one** `WillRepeatedly()`, where *n* >= + 0, the cardinality is `Times(AtLeast(n))`. + +**Quick quiz:** what do you think will happen if a function is expected to be +called twice but actually called four times? + +### Actions: What Should It Do? + +Remember that a mock object doesn't really have a working implementation? We as +users have to tell it what to do when a method is invoked. This is easy in +gMock. + +First, if the return type of a mock function is a built-in type or a pointer, +the function has a **default action** (a `void` function will just return, a +`bool` function will return `false`, and other functions will return 0). In +addition, in C++ 11 and above, a mock function whose return type is +default-constructible (i.e. has a default constructor) has a default action of +returning a default-constructed value. If you don't say anything, this behavior +will be used. 
+ +Second, if a mock function doesn't have a default action, or the default action +doesn't suit you, you can specify the action to be taken each time the +expectation matches using a series of `WillOnce()` clauses followed by an +optional `WillRepeatedly()`. For example, + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(100)) + .WillOnce(Return(200)) + .WillOnce(Return(300)); +``` + +says that `turtle.GetX()` will be called *exactly three times* (gMock inferred +this from how many `WillOnce()` clauses we've written, since we didn't +explicitly write `Times()`), and will return 100, 200, and 300 respectively. + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetY()) + .WillOnce(Return(100)) + .WillOnce(Return(200)) + .WillRepeatedly(Return(300)); +``` + +says that `turtle.GetY()` will be called *at least twice* (gMock knows this as +we've written two `WillOnce()` clauses and a `WillRepeatedly()` while having no +explicit `Times()`), will return 100 and 200 respectively the first two times, +and 300 from the third time on. + +Of course, if you explicitly write a `Times()`, gMock will not try to infer the +cardinality itself. What if the number you specified is larger than there are +`WillOnce()` clauses? Well, after all `WillOnce()`s are used up, gMock will do +the *default* action for the function every time (unless, of course, you have a +`WillRepeatedly()`.). + +What can we do inside `WillOnce()` besides `Return()`? You can return a +reference using `ReturnRef(`*`variable`*`)`, or invoke a pre-defined function, +among [others](gmock_cook_book.md#using-actions). + +**Important note:** The `EXPECT_CALL()` statement evaluates the action clause +only once, even though the action may be performed many times. Therefore you +must be careful about side effects. The following may not do what you want: + +```cpp +using ::testing::Return; +... 
+int n = 100; +EXPECT_CALL(turtle, GetX()) + .Times(4) + .WillRepeatedly(Return(n++)); +``` + +Instead of returning 100, 101, 102, ..., consecutively, this mock function will +always return 100 as `n++` is only evaluated once. Similarly, `Return(new Foo)` +will create a new `Foo` object when the `EXPECT_CALL()` is executed, and will +return the same pointer every time. If you want the side effect to happen every +time, you need to define a custom action, which we'll teach in the +[cook book](gmock_cook_book.md). + +Time for another quiz! What do you think the following means? + +```cpp +using ::testing::Return; +... +EXPECT_CALL(turtle, GetY()) + .Times(4) + .WillOnce(Return(100)); +``` + +Obviously `turtle.GetY()` is expected to be called four times. But if you think +it will return 100 every time, think twice! Remember that one `WillOnce()` +clause will be consumed each time the function is invoked and the default action +will be taken afterwards. So the right answer is that `turtle.GetY()` will +return 100 the first time, but **return 0 from the second time on**, as +returning 0 is the default action for `int` functions. + +### Using Multiple Expectations {#MultiExpectations} + +So far we've only shown examples where you have a single expectation. More +realistically, you'll specify expectations on multiple mock methods which may be +from multiple mock objects. + +By default, when a mock method is invoked, gMock will search the expectations in +the **reverse order** they are defined, and stop when an active expectation that +matches the arguments is found (you can think of it as "newer rules override +older ones."). If the matching expectation cannot take any more calls, you will +get an upper-bound-violated failure. Here's an example: + +```cpp +using ::testing::_; +... 
+EXPECT_CALL(turtle, Forward(_)); // #1 +EXPECT_CALL(turtle, Forward(10)) // #2 + .Times(2); +``` + +If `Forward(10)` is called three times in a row, the third time it will be an +error, as the last matching expectation (#2) has been saturated. If, however, +the third `Forward(10)` call is replaced by `Forward(20)`, then it would be OK, +as now #1 will be the matching expectation. + +{: .callout .note} +**Note:** Why does gMock search for a match in the *reverse* order of the +expectations? The reason is that this allows a user to set up the default +expectations in a mock object's constructor or the test fixture's set-up phase +and then customize the mock by writing more specific expectations in the test +body. So, if you have two expectations on the same method, you want to put the +one with more specific matchers **after** the other, or the more specific rule +would be shadowed by the more general one that comes after it. + +{: .callout .tip} +**Tip:** It is very common to start with a catch-all expectation for a method +and `Times(AnyNumber())` (omitting arguments, or with `_` for all arguments, if +overloaded). This makes any calls to the method expected. This is not necessary +for methods that are not mentioned at all (these are "uninteresting"), but is +useful for methods that have some expectations, but for which other calls are +ok. See +[Understanding Uninteresting vs Unexpected Calls](gmock_cook_book.md#uninteresting-vs-unexpected). + +### Ordered vs Unordered Calls {#OrderedCalls} + +By default, an expectation can match a call even though an earlier expectation +hasn't been satisfied. In other words, the calls don't have to occur in the +order the expectations are specified. + +Sometimes, you may want all the expected calls to occur in a strict order. To +say this in gMock is easy: + +```cpp +using ::testing::InSequence; +... +TEST(FooTest, DrawsLineSegment) { + ... 
+ { + InSequence seq; + + EXPECT_CALL(turtle, PenDown()); + EXPECT_CALL(turtle, Forward(100)); + EXPECT_CALL(turtle, PenUp()); + } + Foo(); +} +``` + +By creating an object of type `InSequence`, all expectations in its scope are +put into a *sequence* and have to occur *sequentially*. Since we are just +relying on the constructor and destructor of this object to do the actual work, +its name is really irrelevant. + +In this example, we test that `Foo()` calls the three expected functions in the +order as written. If a call is made out-of-order, it will be an error. + +(What if you care about the relative order of some of the calls, but not all of +them? Can you specify an arbitrary partial order? The answer is ... yes! The +details can be found [here](gmock_cook_book.md#OrderedCalls).) + +### All Expectations Are Sticky (Unless Said Otherwise) {#StickyExpectations} + +Now let's do a quick quiz to see how well you can use this mock stuff already. +How would you test that the turtle is asked to go to the origin *exactly twice* +(you want to ignore any other instructions it receives)? + +After you've come up with your answer, take a look at ours and compare notes +(solve it yourself first - don't cheat!): + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +... +EXPECT_CALL(turtle, GoTo(_, _)) // #1 + .Times(AnyNumber()); +EXPECT_CALL(turtle, GoTo(0, 0)) // #2 + .Times(2); +``` + +Suppose `turtle.GoTo(0, 0)` is called three times. In the third time, gMock will +see that the arguments match expectation #2 (remember that we always pick the +last matching expectation). Now, since we said that there should be only two +such calls, gMock will report an error immediately. This is basically what we've +told you in the [Using Multiple Expectations](#MultiExpectations) section above. + +This example shows that **expectations in gMock are "sticky" by default**, in +the sense that they remain active even after we have reached their invocation +upper bounds. 
This is an important rule to remember, as it affects the meaning +of the spec, and is **different** to how it's done in many other mocking +frameworks (Why'd we do that? Because we think our rule makes the common cases +easier to express and understand.). + +Simple? Let's see if you've really understood it: what does the following code +say? + +```cpp +using ::testing::Return; +... +for (int i = n; i > 0; i--) { + EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(10*i)); +} +``` + +If you think it says that `turtle.GetX()` will be called `n` times and will +return 10, 20, 30, ..., consecutively, think twice! The problem is that, as we +said, expectations are sticky. So, the second time `turtle.GetX()` is called, +the last (latest) `EXPECT_CALL()` statement will match, and will immediately +lead to an "upper bound violated" error - this piece of code is not very useful! + +One correct way of saying that `turtle.GetX()` will return 10, 20, 30, ..., is +to explicitly say that the expectations are *not* sticky. In other words, they +should *retire* as soon as they are saturated: + +```cpp +using ::testing::Return; +... +for (int i = n; i > 0; i--) { + EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(10*i)) + .RetiresOnSaturation(); +} +``` + +And, there's a better way to do it: in this case, we expect the calls to occur +in a specific order, and we line up the actions to match the order. Since the +order is important here, we should make it explicit using a sequence: + +```cpp +using ::testing::InSequence; +using ::testing::Return; +... +{ + InSequence s; + + for (int i = 1; i <= n; i++) { + EXPECT_CALL(turtle, GetX()) + .WillOnce(Return(10*i)) + .RetiresOnSaturation(); + } +} +``` + +By the way, the other situation where an expectation may *not* be sticky is when +it's in a sequence - as soon as another expectation that comes after it in the +sequence has been used, it automatically retires (and will never be used to +match any call). 
+ +### Uninteresting Calls + +A mock object may have many methods, and not all of them are that interesting. +For example, in some tests we may not care about how many times `GetX()` and +`GetY()` get called. + +In gMock, if you are not interested in a method, just don't say anything about +it. If a call to this method occurs, you'll see a warning in the test output, +but it won't be a failure. This is called "naggy" behavior; to change, see +[The Nice, the Strict, and the Naggy](gmock_cook_book.md#NiceStrictNaggy). diff --git a/third_party/googletest/docs/index.md b/third_party/googletest/docs/index.md new file mode 100644 index 0000000..b162c74 --- /dev/null +++ b/third_party/googletest/docs/index.md @@ -0,0 +1,22 @@ +# GoogleTest User's Guide + +## Welcome to GoogleTest! + +GoogleTest is Google's C++ testing and mocking framework. This user's guide has +the following contents: + +* [GoogleTest Primer](primer.md) - Teaches you how to write simple tests using + GoogleTest. Read this first if you are new to GoogleTest. +* [GoogleTest Advanced](advanced.md) - Read this when you've finished the + Primer and want to utilize GoogleTest to its full potential. +* [GoogleTest Samples](samples.md) - Describes some GoogleTest samples. +* [GoogleTest FAQ](faq.md) - Have a question? Want some tips? Check here + first. +* [Mocking for Dummies](gmock_for_dummies.md) - Teaches you how to create mock + objects and use them in tests. +* [Mocking Cookbook](gmock_cook_book.md) - Includes tips and approaches to + common mocking use cases. +* [Mocking Cheat Sheet](gmock_cheat_sheet.md) - A handy reference for + matchers, actions, invariants, and more. +* [Mocking FAQ](gmock_faq.md) - Contains answers to some mocking-specific + questions. 
diff --git a/third_party/googletest/docs/pkgconfig.md b/third_party/googletest/docs/pkgconfig.md new file mode 100644 index 0000000..bf05d59 --- /dev/null +++ b/third_party/googletest/docs/pkgconfig.md @@ -0,0 +1,144 @@ +## Using GoogleTest from various build systems + +GoogleTest comes with pkg-config files that can be used to determine all +necessary flags for compiling and linking to GoogleTest (and GoogleMock). +Pkg-config is a standardised plain-text format containing + +* the includedir (-I) path +* necessary macro (-D) definitions +* further required flags (-pthread) +* the library (-L) path +* the library (-l) to link to + +All current build systems support pkg-config in one way or another. For all +examples here we assume you want to compile the sample +`samples/sample3_unittest.cc`. + +### CMake + +Using `pkg-config` in CMake is fairly easy: + +```cmake +find_package(PkgConfig) +pkg_search_module(GTEST REQUIRED gtest_main) + +add_executable(testapp) +target_sources(testapp PRIVATE samples/sample3_unittest.cc) +target_link_libraries(testapp PRIVATE ${GTEST_LDFLAGS}) +target_compile_options(testapp PRIVATE ${GTEST_CFLAGS}) + +enable_testing() +add_test(first_and_only_test testapp) +``` + +It is generally recommended that you use `target_compile_options` + `_CFLAGS` +over `target_include_directories` + `_INCLUDE_DIRS` as the former includes not +just -I flags (GoogleTest might require a macro indicating to internal headers +that all libraries have been compiled with threading enabled. In addition, +GoogleTest might also require `-pthread` in the compiling step, and as such +splitting the pkg-config `Cflags` variable into include dirs and macros for +`target_compile_definitions()` might still miss this). The same recommendation +goes for using `_LDFLAGS` over the more commonplace `_LIBRARIES`, which happens +to discard `-L` flags and `-pthread`. + +### Help! pkg-config can't find GoogleTest! 
+ +Let's say you have a `CMakeLists.txt` along the lines of the one in this +tutorial and you try to run `cmake`. It is very possible that you get a failure +along the lines of: + +``` +-- Checking for one of the modules 'gtest_main' +CMake Error at /usr/share/cmake/Modules/FindPkgConfig.cmake:640 (message): + None of the required 'gtest_main' found +``` + +These failures are common if you installed GoogleTest yourself and have not +sourced it from a distro or other package manager. If so, you need to tell +pkg-config where it can find the `.pc` files containing the information. Say you +installed GoogleTest to `/usr/local`, then it might be that the `.pc` files are +installed under `/usr/local/lib64/pkgconfig`. If you set + +``` +export PKG_CONFIG_PATH=/usr/local/lib64/pkgconfig +``` + +pkg-config will also try to look in `PKG_CONFIG_PATH` to find `gtest_main.pc`. + +### Using pkg-config in a cross-compilation setting + +Pkg-config can be used in a cross-compilation setting too. To do this, let's +assume the final prefix of the cross-compiled installation will be `/usr`, and +your sysroot is `/home/MYUSER/sysroot`. Configure and install GTest using + +``` +mkdir build && cmake -DCMAKE_INSTALL_PREFIX=/usr .. +``` + +Install into the sysroot using `DESTDIR`: + +``` +make -j install DESTDIR=/home/MYUSER/sysroot +``` + +Before we continue, it is recommended to **always** define the following two +variables for pkg-config in a cross-compilation setting: + +``` +export PKG_CONFIG_ALLOW_SYSTEM_CFLAGS=yes +export PKG_CONFIG_ALLOW_SYSTEM_LIBS=yes +``` + +otherwise `pkg-config` will filter `-I` and `-L` flags against standard prefixes +such as `/usr` (see https://bugs.freedesktop.org/show_bug.cgi?id=28264#c3 for +reasons why this stripping needs to occur usually). 
+ +If you look at the generated pkg-config file, it will look something like + +``` +libdir=/usr/lib64 +includedir=/usr/include + +Name: gtest +Description: GoogleTest (without main() function) +Version: 1.11.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgtest -lpthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -lpthread +``` + +Notice that the sysroot is not included in `libdir` and `includedir`! If you try +to run `pkg-config` with the correct +`PKG_CONFIG_LIBDIR=/home/MYUSER/sysroot/usr/lib64/pkgconfig` against this `.pc` +file, you will get + +``` +$ pkg-config --cflags gtest +-DGTEST_HAS_PTHREAD=1 -lpthread -I/usr/include +$ pkg-config --libs gtest +-L/usr/lib64 -lgtest -lpthread +``` + +which is obviously wrong and points to the `CBUILD` and not `CHOST` root. In +order to use this in a cross-compilation setting, we need to tell pkg-config to +inject the actual sysroot into `-I` and `-L` variables. Let us now tell +pkg-config about the actual sysroot + +``` +export PKG_CONFIG_DIR= +export PKG_CONFIG_SYSROOT_DIR=/home/MYUSER/sysroot +export PKG_CONFIG_LIBDIR=${PKG_CONFIG_SYSROOT_DIR}/usr/lib64/pkgconfig +``` + +and running `pkg-config` again we get + +``` +$ pkg-config --cflags gtest +-DGTEST_HAS_PTHREAD=1 -lpthread -I/home/MYUSER/sysroot/usr/include +$ pkg-config --libs gtest +-L/home/MYUSER/sysroot/usr/lib64 -lgtest -lpthread +``` + +which contains the correct sysroot now. For a more comprehensive guide to also +including `${CHOST}` in build system calls, see the excellent tutorial by Diego +Elio Pettenò: diff --git a/third_party/googletest/docs/platforms.md b/third_party/googletest/docs/platforms.md new file mode 100644 index 0000000..d35a7be --- /dev/null +++ b/third_party/googletest/docs/platforms.md @@ -0,0 +1,8 @@ +# Supported Platforms + +GoogleTest follows Google's +[Foundational C++ Support Policy](https://opensource.google/documentation/policies/cplusplus-support). 
+See
+[this table](https://github.com/google/oss-policies-info/blob/main/foundational-cxx-support-matrix.md)
+for a list of currently supported versions of compilers, platforms, and build
+tools. diff --git a/third_party/googletest/docs/primer.md b/third_party/googletest/docs/primer.md new file mode 100644 index 0000000..f2a97a7 --- /dev/null +++ b/third_party/googletest/docs/primer.md @@ -0,0 +1,483 @@ +# GoogleTest Primer
+
+## Introduction: Why GoogleTest?
+
+*GoogleTest* helps you write better C++ tests.
+
+GoogleTest is a testing framework developed by the Testing Technology team with
+Google's specific requirements and constraints in mind. Whether you work on
+Linux, Windows, or a Mac, if you write C++ code, GoogleTest can help you. And it
+supports *any* kind of tests, not just unit tests.
+
+So what makes a good test, and how does GoogleTest fit in? We believe:
+
+1. Tests should be *independent* and *repeatable*. It's a pain to debug a test
+   that succeeds or fails as a result of other tests. GoogleTest isolates the
+   tests by running each of them on a different object. When a test fails,
+   GoogleTest allows you to run it in isolation for quick debugging.
+2. Tests should be well *organized* and reflect the structure of the tested
+   code. GoogleTest groups related tests into test suites that can share data
+   and subroutines. This common pattern is easy to recognize and makes tests
+   easy to maintain. Such consistency is especially helpful when people switch
+   projects and start to work on a new code base.
+3. Tests should be *portable* and *reusable*. Google has a lot of code that is
+   platform-neutral; its tests should also be platform-neutral. GoogleTest
+   works on different OSes, with different compilers, with or without
+   exceptions, so GoogleTest tests can work with a variety of configurations.
+4. When tests fail, they should provide as much *information* about the problem
+   as possible. GoogleTest doesn't stop at the first test failure.
Instead, it + only stops the current test and continues with the next. You can also set up + tests that report non-fatal failures after which the current test continues. + Thus, you can detect and fix multiple bugs in a single run-edit-compile + cycle. +5. The testing framework should liberate test writers from housekeeping chores + and let them focus on the test *content*. GoogleTest automatically keeps + track of all tests defined, and doesn't require the user to enumerate them + in order to run them. +6. Tests should be *fast*. With GoogleTest, you can reuse shared resources + across tests and pay for the set-up/tear-down only once, without making + tests depend on each other. + +Since GoogleTest is based on the popular xUnit architecture, you'll feel right +at home if you've used JUnit or PyUnit before. If not, it will take you about 10 +minutes to learn the basics and get started. So let's go! + +## Beware of the Nomenclature + +{: .callout .note} +*Note:* There might be some confusion arising from different definitions of the +terms *Test*, *Test Case* and *Test Suite*, so beware of misunderstanding these. + +Historically, GoogleTest started to use the term *Test Case* for grouping +related tests, whereas current publications, including International Software +Testing Qualifications Board ([ISTQB](http://www.istqb.org/)) materials and +various textbooks on software quality, use the term +*[Test Suite][istqb test suite]* for this. + +The related term *Test*, as it is used in GoogleTest, corresponds to the term +*[Test Case][istqb test case]* of ISTQB and others. + +The term *Test* is commonly of broad enough sense, including ISTQB's definition +of *Test Case*, so it's not much of a problem here. But the term *Test Case* as +was used in Google Test is of contradictory sense and thus confusing. + +GoogleTest recently started replacing the term *Test Case* with *Test Suite*. +The preferred API is *TestSuite*. 
The older TestCase API is being slowly +deprecated and refactored away. + +So please be aware of the different definitions of the terms: + + +Meaning | GoogleTest Term | [ISTQB](http://www.istqb.org/) Term +:----------------------------------------------------------------------------------- | :---------------------- | :---------------------------------- +Exercise a particular program path with specific input values and verify the results | [TEST()](#simple-tests) | [Test Case][istqb test case] + + +[istqb test case]: http://glossary.istqb.org/en/search/test%20case +[istqb test suite]: http://glossary.istqb.org/en/search/test%20suite + +## Basic Concepts + +When using GoogleTest, you start by writing *assertions*, which are statements +that check whether a condition is true. An assertion's result can be *success*, +*nonfatal failure*, or *fatal failure*. If a fatal failure occurs, it aborts the +current function; otherwise the program continues normally. + +*Tests* use assertions to verify the tested code's behavior. If a test crashes +or has a failed assertion, then it *fails*; otherwise it *succeeds*. + +A *test suite* contains one or many tests. You should group your tests into test +suites that reflect the structure of the tested code. When multiple tests in a +test suite need to share common objects and subroutines, you can put them into a +*test fixture* class. + +A *test program* can contain multiple test suites. + +We'll now explain how to write a test program, starting at the individual +assertion level and building up to tests and test suites. + +## Assertions + +GoogleTest assertions are macros that resemble function calls. You test a class +or function by making assertions about its behavior. When an assertion fails, +GoogleTest prints the assertion's source file and line number location, along +with a failure message. You may also supply a custom failure message which will +be appended to GoogleTest's message. 
+ +The assertions come in pairs that test the same thing but have different effects +on the current function. `ASSERT_*` versions generate fatal failures when they +fail, and **abort the current function**. `EXPECT_*` versions generate nonfatal +failures, which don't abort the current function. Usually `EXPECT_*` are +preferred, as they allow more than one failure to be reported in a test. +However, you should use `ASSERT_*` if it doesn't make sense to continue when the +assertion in question fails. + +Since a failed `ASSERT_*` returns from the current function immediately, +possibly skipping clean-up code that comes after it, it may cause a space leak. +Depending on the nature of the leak, it may or may not be worth fixing - so keep +this in mind if you get a heap checker error in addition to assertion errors. + +To provide a custom failure message, simply stream it into the macro using the +`<<` operator or a sequence of such operators. See the following example, using +the [`ASSERT_EQ` and `EXPECT_EQ`](reference/assertions.md#EXPECT_EQ) macros to +verify value equality: + +```c++ +ASSERT_EQ(x.size(), y.size()) << "Vectors x and y are of unequal length"; + +for (int i = 0; i < x.size(); ++i) { + EXPECT_EQ(x[i], y[i]) << "Vectors x and y differ at index " << i; +} +``` + +Anything that can be streamed to an `ostream` can be streamed to an assertion +macro--in particular, C strings and `string` objects. If a wide string +(`wchar_t*`, `TCHAR*` in `UNICODE` mode on Windows, or `std::wstring`) is +streamed to an assertion, it will be translated to UTF-8 when printed. + +GoogleTest provides a collection of assertions for verifying the behavior of +your code in various ways. You can check Boolean conditions, compare values +based on relational operators, verify string values, floating-point values, and +much more. There are even assertions that enable you to verify more complex +states by providing custom predicates. 
For the complete list of assertions +provided by GoogleTest, see the [Assertions Reference](reference/assertions.md). + +## Simple Tests + +To create a test: + +1. Use the `TEST()` macro to define and name a test function. These are + ordinary C++ functions that don't return a value. +2. In this function, along with any valid C++ statements you want to include, + use the various GoogleTest assertions to check values. +3. The test's result is determined by the assertions; if any assertion in the + test fails (either fatally or non-fatally), or if the test crashes, the + entire test fails. Otherwise, it succeeds. + +```c++ +TEST(TestSuiteName, TestName) { + ... test body ... +} +``` + +`TEST()` arguments go from general to specific. The *first* argument is the name +of the test suite, and the *second* argument is the test's name within the test +suite. Both names must be valid C++ identifiers, and they should not contain any +underscores (`_`). A test's *full name* consists of its containing test suite +and its individual name. Tests from different test suites can have the same +individual name. + +For example, let's take a simple integer function: + +```c++ +int Factorial(int n); // Returns the factorial of n +``` + +A test suite for this function might look like: + +```c++ +// Tests factorial of 0. +TEST(FactorialTest, HandlesZeroInput) { + EXPECT_EQ(Factorial(0), 1); +} + +// Tests factorial of positive numbers. +TEST(FactorialTest, HandlesPositiveInput) { + EXPECT_EQ(Factorial(1), 1); + EXPECT_EQ(Factorial(2), 2); + EXPECT_EQ(Factorial(3), 6); + EXPECT_EQ(Factorial(8), 40320); +} +``` + +GoogleTest groups the test results by test suites, so logically related tests +should be in the same test suite; in other words, the first argument to their +`TEST()` should be the same. In the above example, we have two tests, +`HandlesZeroInput` and `HandlesPositiveInput`, that belong to the same test +suite `FactorialTest`. 
+ +When naming your test suites and tests, you should follow the same convention as +for +[naming functions and classes](https://google.github.io/styleguide/cppguide.html#Function_Names). + +**Availability**: Linux, Windows, Mac. + +## Test Fixtures: Using the Same Data Configuration for Multiple Tests {#same-data-multiple-tests} + +If you find yourself writing two or more tests that operate on similar data, you +can use a *test fixture*. This allows you to reuse the same configuration of +objects for several different tests. + +To create a fixture: + +1. Derive a class from `::testing::Test` . Start its body with `protected:`, as + we'll want to access fixture members from sub-classes. +2. Inside the class, declare any objects you plan to use. +3. If necessary, write a default constructor or `SetUp()` function to prepare + the objects for each test. A common mistake is to spell `SetUp()` as + **`Setup()`** with a small `u` - Use `override` in C++11 to make sure you + spelled it correctly. +4. If necessary, write a destructor or `TearDown()` function to release any + resources you allocated in `SetUp()` . To learn when you should use the + constructor/destructor and when you should use `SetUp()/TearDown()`, read + the [FAQ](faq.md#CtorVsSetUp). +5. If needed, define subroutines for your tests to share. + +When using a fixture, use `TEST_F()` instead of `TEST()` as it allows you to +access objects and subroutines in the test fixture: + +```c++ +TEST_F(TestFixtureClassName, TestName) { + ... test body ... +} +``` + +Unlike `TEST()`, in `TEST_F()` the first argument must be the name of the test +fixture class. (`_F` stands for "Fixture"). No test suite name is specified for +this macro. + +Unfortunately, the C++ macro system does not allow us to create a single macro +that can handle both types of tests. Using the wrong macro causes a compiler +error. 
+
+Also, you must first define a test fixture class before using it in a
+`TEST_F()`, or you'll get the compiler error "`virtual outside class
+declaration`".
+
+For each test defined with `TEST_F()`, GoogleTest will create a *fresh* test
+fixture at runtime, immediately initialize it via `SetUp()`, run the test, clean
+up by calling `TearDown()`, and then delete the test fixture. Note that
+different tests in the same test suite have different test fixture objects, and
+GoogleTest always deletes a test fixture before it creates the next one.
+GoogleTest does **not** reuse the same test fixture for multiple tests. Any
+changes one test makes to the fixture do not affect other tests.
+
+As an example, let's write tests for a FIFO queue class named `Queue`, which has
+the following interface:
+
+```c++
+template <typename E>  // E is the element type.
+class Queue {
+ public:
+  Queue();
+  void Enqueue(const E& element);
+  E* Dequeue();  // Returns NULL if the queue is empty.
+  size_t size() const;
+  ...
+};
+```
+
+First, define a fixture class. By convention, you should give it the name
+`FooTest` where `Foo` is the class being tested.
+
+```c++
+class QueueTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+     // q0_ remains empty
+     q1_.Enqueue(1);
+     q2_.Enqueue(2);
+     q2_.Enqueue(3);
+  }
+
+  // void TearDown() override {}
+
+  Queue<int> q0_;
+  Queue<int> q1_;
+  Queue<int> q2_;
+};
+```
+
+In this case, `TearDown()` is not needed since we don't have to clean up after
+each test, other than what's already done by the destructor.
+
+Now we'll write tests using `TEST_F()` and this fixture.
+ +```c++ +TEST_F(QueueTest, IsEmptyInitially) { + EXPECT_EQ(q0_.size(), 0); +} + +TEST_F(QueueTest, DequeueWorks) { + int* n = q0_.Dequeue(); + EXPECT_EQ(n, nullptr); + + n = q1_.Dequeue(); + ASSERT_NE(n, nullptr); + EXPECT_EQ(*n, 1); + EXPECT_EQ(q1_.size(), 0); + delete n; + + n = q2_.Dequeue(); + ASSERT_NE(n, nullptr); + EXPECT_EQ(*n, 2); + EXPECT_EQ(q2_.size(), 1); + delete n; +} +``` + +The above uses both `ASSERT_*` and `EXPECT_*` assertions. The rule of thumb is +to use `EXPECT_*` when you want the test to continue to reveal more errors after +the assertion failure, and use `ASSERT_*` when continuing after failure doesn't +make sense. For example, the second assertion in the `Dequeue` test is +`ASSERT_NE(n, nullptr)`, as we need to dereference the pointer `n` later, which +would lead to a segfault when `n` is `NULL`. + +When these tests run, the following happens: + +1. GoogleTest constructs a `QueueTest` object (let's call it `t1`). +2. `t1.SetUp()` initializes `t1`. +3. The first test (`IsEmptyInitially`) runs on `t1`. +4. `t1.TearDown()` cleans up after the test finishes. +5. `t1` is destructed. +6. The above steps are repeated on another `QueueTest` object, this time + running the `DequeueWorks` test. + +**Availability**: Linux, Windows, Mac. + +## Invoking the Tests + +`TEST()` and `TEST_F()` implicitly register their tests with GoogleTest. So, +unlike with many other C++ testing frameworks, you don't have to re-list all +your defined tests in order to run them. + +After defining your tests, you can run them with `RUN_ALL_TESTS()`, which +returns `0` if all the tests are successful, or `1` otherwise. Note that +`RUN_ALL_TESTS()` runs *all tests* in your link unit--they can be from different +test suites, or even different source files. + +When invoked, the `RUN_ALL_TESTS()` macro: + +* Saves the state of all GoogleTest flags. + +* Creates a test fixture object for the first test. + +* Initializes it via `SetUp()`. 
+ +* Runs the test on the fixture object. + +* Cleans up the fixture via `TearDown()`. + +* Deletes the fixture. + +* Restores the state of all GoogleTest flags. + +* Repeats the above steps for the next test, until all tests have run. + +If a fatal failure happens the subsequent steps will be skipped. + +{: .callout .important} +> IMPORTANT: You must **not** ignore the return value of `RUN_ALL_TESTS()`, or +> you will get a compiler error. The rationale for this design is that the +> automated testing service determines whether a test has passed based on its +> exit code, not on its stdout/stderr output; thus your `main()` function must +> return the value of `RUN_ALL_TESTS()`. +> +> Also, you should call `RUN_ALL_TESTS()` only **once**. Calling it more than +> once conflicts with some advanced GoogleTest features (e.g., thread-safe +> [death tests](advanced.md#death-tests)) and thus is not supported. + +**Availability**: Linux, Windows, Mac. + +## Writing the main() Function + +Most users should *not* need to write their own `main` function and instead link +with `gtest_main` (as opposed to with `gtest`), which defines a suitable entry +point. See the end of this section for details. The remainder of this section +should only apply when you need to do something custom before the tests run that +cannot be expressed within the framework of fixtures and test suites. + +If you write your own `main` function, it should return the value of +`RUN_ALL_TESTS()`. + +You can start from this boilerplate: + +```c++ +#include "this/package/foo.h" + +#include + +namespace my { +namespace project { +namespace { + +// The fixture for testing class Foo. +class FooTest : public ::testing::Test { + protected: + // You can remove any or all of the following functions if their bodies would + // be empty. + + FooTest() { + // You can do set-up work for each test here. + } + + ~FooTest() override { + // You can do clean-up work that doesn't throw exceptions here. 
+ } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + } + + void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + // Class members declared here can be used by all tests in the test suite + // for Foo. +}; + +// Tests that the Foo::Bar() method does Abc. +TEST_F(FooTest, MethodBarDoesAbc) { + const std::string input_filepath = "this/package/testdata/myinputfile.dat"; + const std::string output_filepath = "this/package/testdata/myoutputfile.dat"; + Foo f; + EXPECT_EQ(f.Bar(input_filepath, output_filepath), 0); +} + +// Tests that Foo does Xyz. +TEST_F(FooTest, DoesXyz) { + // Exercises the Xyz feature of Foo. +} + +} // namespace +} // namespace project +} // namespace my + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +``` + +The `::testing::InitGoogleTest()` function parses the command line for +GoogleTest flags, and removes all recognized flags. This allows the user to +control a test program's behavior via various flags, which we'll cover in the +[AdvancedGuide](advanced.md). You **must** call this function before calling +`RUN_ALL_TESTS()`, or the flags won't be properly initialized. + +On Windows, `InitGoogleTest()` also works with wide strings, so it can be used +in programs compiled in `UNICODE` mode as well. + +But maybe you think that writing all those `main` functions is too much work? We +agree with you completely, and that's why Google Test provides a basic +implementation of main(). If it fits your needs, then just link your test with +the `gtest_main` library and you are good to go. + +{: .callout .note} +NOTE: `ParseGUnitFlags()` is deprecated in favor of `InitGoogleTest()`. 
+ +## Known Limitations + +* Google Test is designed to be thread-safe. The implementation is thread-safe + on systems where the `pthreads` library is available. It is currently + *unsafe* to use Google Test assertions from two threads concurrently on + other systems (e.g. Windows). In most tests this is not an issue as usually + the assertions are done in the main thread. If you want to help, you can + volunteer to implement the necessary synchronization primitives in + `gtest-port.h` for your platform. diff --git a/third_party/googletest/docs/quickstart-bazel.md b/third_party/googletest/docs/quickstart-bazel.md new file mode 100644 index 0000000..4f693db --- /dev/null +++ b/third_party/googletest/docs/quickstart-bazel.md @@ -0,0 +1,153 @@ +# Quickstart: Building with Bazel + +This tutorial aims to get you up and running with GoogleTest using the Bazel +build system. If you're using GoogleTest for the first time or need a refresher, +we recommend this tutorial as a starting point. + +## Prerequisites + +To complete this tutorial, you'll need: + +* A compatible operating system (e.g. Linux, macOS, Windows). +* A compatible C++ compiler that supports at least C++14. +* [Bazel](https://bazel.build/), the preferred build system used by the + GoogleTest team. + +See [Supported Platforms](platforms.md) for more information about platforms +compatible with GoogleTest. + +If you don't already have Bazel installed, see the +[Bazel installation guide](https://bazel.build/install). + +{: .callout .note} Note: The terminal commands in this tutorial show a Unix +shell prompt, but the commands work on the Windows command line as well. + +## Set up a Bazel workspace + +A +[Bazel workspace](https://docs.bazel.build/versions/main/build-ref.html#workspace) +is a directory on your filesystem that you use to manage source files for the +software you want to build. 
Each workspace directory has a text file named +`WORKSPACE` which may be empty, or may contain references to external +dependencies required to build the outputs. + +First, create a directory for your workspace: + +``` +$ mkdir my_workspace && cd my_workspace +``` + +Next, you’ll create the `WORKSPACE` file to specify dependencies. A common and +recommended way to depend on GoogleTest is to use a +[Bazel external dependency](https://docs.bazel.build/versions/main/external.html) +via the +[`http_archive` rule](https://docs.bazel.build/versions/main/repo/http.html#http_archive). +To do this, in the root directory of your workspace (`my_workspace/`), create a +file named `WORKSPACE` with the following contents: + +``` +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/5ab508a01f9eb089207ee87fd547d290da39d015.zip"], + strip_prefix = "googletest-5ab508a01f9eb089207ee87fd547d290da39d015", +) +``` + +The above configuration declares a dependency on GoogleTest which is downloaded +as a ZIP archive from GitHub. In the above example, +`5ab508a01f9eb089207ee87fd547d290da39d015` is the Git commit hash of the +GoogleTest version to use; we recommend updating the hash often to point to the +latest version. Use a recent hash on the `main` branch. + +Now you're ready to build C++ code that uses GoogleTest. + +## Create and run a binary + +With your Bazel workspace set up, you can now use GoogleTest code within your +own project. + +As an example, create a file named `hello_test.cc` in your `my_workspace` +directory with the following contents: + +```cpp +#include + +// Demonstrate some basic assertions. +TEST(HelloTest, BasicAssertions) { + // Expect two strings not to be equal. + EXPECT_STRNE("hello", "world"); + // Expect equality. 
+ EXPECT_EQ(7 * 6, 42); +} +``` + +GoogleTest provides [assertions](primer.md#assertions) that you use to test the +behavior of your code. The above sample includes the main GoogleTest header file +and demonstrates some basic assertions. + +To build the code, create a file named `BUILD` in the same directory with the +following contents: + +``` +cc_test( + name = "hello_test", + size = "small", + srcs = ["hello_test.cc"], + deps = ["@com_google_googletest//:gtest_main"], +) +``` + +This `cc_test` rule declares the C++ test binary you want to build, and links to +GoogleTest (`//:gtest_main`) using the prefix you specified in the `WORKSPACE` +file (`@com_google_googletest`). For more information about Bazel `BUILD` files, +see the +[Bazel C++ Tutorial](https://docs.bazel.build/versions/main/tutorial/cpp.html). + +{: .callout .note} +NOTE: In the example below, we assume Clang or GCC and set `--cxxopt=-std=c++14` +to ensure that GoogleTest is compiled as C++14 instead of the compiler's default +setting (which could be C++11). For MSVC, the equivalent would be +`--cxxopt=/std:c++14`. See [Supported Platforms](platforms.md) for more details +on supported language versions. + +Now you can build and run your test: + +
+my_workspace$ bazel test --cxxopt=-std=c++14 --test_output=all //:hello_test
+INFO: Analyzed target //:hello_test (26 packages loaded, 362 targets configured).
+INFO: Found 1 test target...
+INFO: From Testing //:hello_test:
+==================== Test output for //:hello_test:
+Running main() from gmock_main.cc
+[==========] Running 1 test from 1 test suite.
+[----------] Global test environment set-up.
+[----------] 1 test from HelloTest
+[ RUN      ] HelloTest.BasicAssertions
+[       OK ] HelloTest.BasicAssertions (0 ms)
+[----------] 1 test from HelloTest (0 ms total)
+
+[----------] Global test environment tear-down
+[==========] 1 test from 1 test suite ran. (0 ms total)
+[  PASSED  ] 1 test.
+================================================================================
+Target //:hello_test up-to-date:
+  bazel-bin/hello_test
+INFO: Elapsed time: 4.190s, Critical Path: 3.05s
+INFO: 27 processes: 8 internal, 19 linux-sandbox.
+INFO: Build completed successfully, 27 total actions
+//:hello_test                                                     PASSED in 0.1s
+
+INFO: Build completed successfully, 27 total actions
+
+ +Congratulations! You've successfully built and run a test binary using +GoogleTest. + +## Next steps + +* [Check out the Primer](primer.md) to start learning how to write simple + tests. +* [See the code samples](samples.md) for more examples showing how to use a + variety of GoogleTest features. diff --git a/third_party/googletest/docs/quickstart-cmake.md b/third_party/googletest/docs/quickstart-cmake.md new file mode 100644 index 0000000..4e422b7 --- /dev/null +++ b/third_party/googletest/docs/quickstart-cmake.md @@ -0,0 +1,157 @@ +# Quickstart: Building with CMake + +This tutorial aims to get you up and running with GoogleTest using CMake. If +you're using GoogleTest for the first time or need a refresher, we recommend +this tutorial as a starting point. If your project uses Bazel, see the +[Quickstart for Bazel](quickstart-bazel.md) instead. + +## Prerequisites + +To complete this tutorial, you'll need: + +* A compatible operating system (e.g. Linux, macOS, Windows). +* A compatible C++ compiler that supports at least C++14. +* [CMake](https://cmake.org/) and a compatible build tool for building the + project. + * Compatible build tools include + [Make](https://www.gnu.org/software/make/), + [Ninja](https://ninja-build.org/), and others - see + [CMake Generators](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html) + for more information. + +See [Supported Platforms](platforms.md) for more information about platforms +compatible with GoogleTest. + +If you don't already have CMake installed, see the +[CMake installation guide](https://cmake.org/install). + +{: .callout .note} +Note: The terminal commands in this tutorial show a Unix shell prompt, but the +commands work on the Windows command line as well. + +## Set up a project + +CMake uses a file named `CMakeLists.txt` to configure the build system for a +project. You'll use this file to set up your project and declare a dependency on +GoogleTest. 
+ +First, create a directory for your project: + +``` +$ mkdir my_project && cd my_project +``` + +Next, you'll create the `CMakeLists.txt` file and declare a dependency on +GoogleTest. There are many ways to express dependencies in the CMake ecosystem; +in this quickstart, you'll use the +[`FetchContent` CMake module](https://cmake.org/cmake/help/latest/module/FetchContent.html). +To do this, in your project directory (`my_project`), create a file named +`CMakeLists.txt` with the following contents: + +```cmake +cmake_minimum_required(VERSION 3.14) +project(my_project) + +# GoogleTest requires at least C++14 +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) +``` + +The above configuration declares a dependency on GoogleTest which is downloaded +from GitHub. In the above example, `03597a01ee50ed33e9dfd640b249b4be3799d395` is +the Git commit hash of the GoogleTest version to use; we recommend updating the +hash often to point to the latest version. + +For more information about how to create `CMakeLists.txt` files, see the +[CMake Tutorial](https://cmake.org/cmake/help/latest/guide/tutorial/index.html). + +## Create and run a binary + +With GoogleTest declared as a dependency, you can use GoogleTest code within +your own project. + +As an example, create a file named `hello_test.cc` in your `my_project` +directory with the following contents: + +```cpp +#include + +// Demonstrate some basic assertions. +TEST(HelloTest, BasicAssertions) { + // Expect two strings not to be equal. + EXPECT_STRNE("hello", "world"); + // Expect equality. 
+ EXPECT_EQ(7 * 6, 42); +} +``` + +GoogleTest provides [assertions](primer.md#assertions) that you use to test the +behavior of your code. The above sample includes the main GoogleTest header file +and demonstrates some basic assertions. + +To build the code, add the following to the end of your `CMakeLists.txt` file: + +```cmake +enable_testing() + +add_executable( + hello_test + hello_test.cc +) +target_link_libraries( + hello_test + GTest::gtest_main +) + +include(GoogleTest) +gtest_discover_tests(hello_test) +``` + +The above configuration enables testing in CMake, declares the C++ test binary +you want to build (`hello_test`), and links it to GoogleTest (`gtest_main`). The +last two lines enable CMake's test runner to discover the tests included in the +binary, using the +[`GoogleTest` CMake module](https://cmake.org/cmake/help/git-stage/module/GoogleTest.html). + +Now you can build and run your test: + +
+my_project$ cmake -S . -B build
+-- The C compiler identification is GNU 10.2.1
+-- The CXX compiler identification is GNU 10.2.1
+...
+-- Build files have been written to: .../my_project/build
+
+my_project$ cmake --build build
+Scanning dependencies of target gtest
+...
+[100%] Built target gmock_main
+
+my_project$ cd build && ctest
+Test project .../my_project/build
+    Start 1: HelloTest.BasicAssertions
+1/1 Test #1: HelloTest.BasicAssertions ........   Passed    0.00 sec
+
+100% tests passed, 0 tests failed out of 1
+
+Total Test time (real) =   0.01 sec
+
+
+Congratulations! You've successfully built and run a test binary using
+GoogleTest.
+
+## Next steps
+
+*   [Check out the Primer](primer.md) to start learning how to write simple
+    tests.
+*   [See the code samples](samples.md) for more examples showing how to use a
+    variety of GoogleTest features.
diff --git a/third_party/googletest/docs/reference/actions.md b/third_party/googletest/docs/reference/actions.md
new file mode 100644
index 0000000..ab81a12
--- /dev/null
+++ b/third_party/googletest/docs/reference/actions.md
@@ -0,0 +1,115 @@
+# Actions Reference
+
+[**Actions**](../gmock_for_dummies.md#actions-what-should-it-do) specify what a
+mock function should do when invoked. This page lists the built-in actions
+provided by GoogleTest. All actions are defined in the `::testing` namespace.
+
+## Returning a Value
+
+| Action | Description |
+| :-------------------------------- | :-------------------------------------------- |
+| `Return()` | Return from a `void` mock function. |
+| `Return(value)` | Return `value`. If the type of `value` is different to the mock function's return type, `value` is converted to the latter type at the time the expectation is set, not when the action is executed. |
+| `ReturnArg<N>()` | Return the `N`-th (0-based) argument. |
+| `ReturnNew<T>(a1, ..., ak)` | Return `new T(a1, ..., ak)`; a different object is created each time. |
+| `ReturnNull()` | Return a null pointer. |
+| `ReturnPointee(ptr)` | Return the value pointed to by `ptr`. |
+| `ReturnRef(variable)` | Return a reference to `variable`. |
+| `ReturnRefOfCopy(value)` | Return a reference to a copy of `value`; the copy lives as long as the action. |
+| `ReturnRoundRobin({a1, ..., ak})` | Each call will return the next `ai` in the list, starting at the beginning when the end of the list is reached. |
+
+## Side Effects
+
+| Action | Description |
+| :--------------------------------- | :-------------------------------------- |
+| `Assign(&variable, value)` | Assign `value` to variable. |
+| `DeleteArg<N>()` | Delete the `N`-th (0-based) argument, which must be a pointer. |
+| `SaveArg<N>(pointer)` | Save the `N`-th (0-based) argument to `*pointer`. |
+| `SaveArgPointee<N>(pointer)` | Save the value pointed to by the `N`-th (0-based) argument to `*pointer`. |
+| `SetArgReferee<N>(value)` | Assign `value` to the variable referenced by the `N`-th (0-based) argument. |
+| `SetArgPointee<N>(value)` | Assign `value` to the variable pointed by the `N`-th (0-based) argument. |
+| `SetArgumentPointee<N>(value)` | Same as `SetArgPointee<N>(value)`. Deprecated. Will be removed in v1.7.0. |
+| `SetArrayArgument<N>(first, last)` | Copies the elements in source range [`first`, `last`) to the array pointed to by the `N`-th (0-based) argument, which can be either a pointer or an iterator. The action does not take ownership of the elements in the source range. |
+| `SetErrnoAndReturn(error, value)` | Set `errno` to `error` and return `value`. |
+| `Throw(exception)` | Throws the given exception, which can be any copyable value. Available since v1.1.0. |
+
+## Using a Function, Functor, or Lambda as an Action
+
+In the following, by "callable" we mean a free function, `std::function`,
+functor, or lambda.
+
+| Action | Description |
+| :---------------------------------- | :------------------------------------- |
+| `f` | Invoke `f` with the arguments passed to the mock function, where `f` is a callable. |
+| `Invoke(f)` | Invoke `f` with the arguments passed to the mock function, where `f` can be a global/static function or a functor. |
+| `Invoke(object_pointer, &class::method)` | Invoke the method on the object with the arguments passed to the mock function. |
+| `InvokeWithoutArgs(f)` | Invoke `f`, which can be a global/static function or a functor. `f` must take no arguments. |
+| `InvokeWithoutArgs(object_pointer, &class::method)` | Invoke the method on the object, which takes no arguments. |
+| `InvokeArgument<N>(arg1, arg2, ..., argk)` | Invoke the mock function's `N`-th (0-based) argument, which must be a function or a functor, with the `k` arguments. |
+
+The return value of the invoked function is used as the return value of the
+action.
+
+When defining a callable to be used with `Invoke*()`, you can declare any unused
+parameters as `Unused`:
+
+```cpp
+using ::testing::Invoke;
+double Distance(Unused, double x, double y) { return sqrt(x*x + y*y); }
+...
+EXPECT_CALL(mock, Foo("Hi", _, _)).WillOnce(Invoke(Distance));
+```
+
+`Invoke(callback)` and `InvokeWithoutArgs(callback)` take ownership of
+`callback`, which must be permanent. The type of `callback` must be a base
+callback type instead of a derived one, e.g.
+
+```cpp
+  BlockingClosure* done = new BlockingClosure;
+  ... Invoke(done) ...;  // This won't compile!
+
+  Closure* done2 = new BlockingClosure;
+  ... Invoke(done2) ...;  // This works.
+```
+
+In `InvokeArgument<N>(...)`, if an argument needs to be passed by reference,
+wrap it inside `std::ref()`. For example,
+
+```cpp
+using ::testing::InvokeArgument;
+...
+InvokeArgument<2>(5, string("Hi"), std::ref(foo))
+```
+
+calls the mock function's #2 argument, passing to it `5` and `string("Hi")` by
+value, and `foo` by reference.
+
+## Default Action
+
+| Action | Description |
+| :------------ | :----------------------------------------------------- |
+| `DoDefault()` | Do the default action (specified by `ON_CALL()` or the built-in one). |
+
+{: .callout .note}
+**Note:** due to technical reasons, `DoDefault()` cannot be used inside a
+composite action - trying to do so will result in a run-time error.
+ +## Composite Actions + +| Action | Description | +| :----------------------------- | :------------------------------------------ | +| `DoAll(a1, a2, ..., an)` | Do all actions `a1` to `an` and return the result of `an` in each invocation. The first `n - 1` sub-actions must return void and will receive a readonly view of the arguments. | +| `IgnoreResult(a)` | Perform action `a` and ignore its result. `a` must not return void. | +| `WithArg(a)` | Pass the `N`-th (0-based) argument of the mock function to action `a` and perform it. | +| `WithArgs(a)` | Pass the selected (0-based) arguments of the mock function to action `a` and perform it. | +| `WithoutArgs(a)` | Perform action `a` without any arguments. | + +## Defining Actions + +| Macro | Description | +| :--------------------------------- | :-------------------------------------- | +| `ACTION(Sum) { return arg0 + arg1; }` | Defines an action `Sum()` to return the sum of the mock function's argument #0 and #1. | +| `ACTION_P(Plus, n) { return arg0 + n; }` | Defines an action `Plus(n)` to return the sum of the mock function's argument #0 and `n`. | +| `ACTION_Pk(Foo, p1, ..., pk) { statements; }` | Defines a parameterized action `Foo(p1, ..., pk)` to execute the given `statements`. | + +The `ACTION*` macros cannot be used inside a function or class. diff --git a/third_party/googletest/docs/reference/assertions.md b/third_party/googletest/docs/reference/assertions.md new file mode 100644 index 0000000..aa1dbc0 --- /dev/null +++ b/third_party/googletest/docs/reference/assertions.md @@ -0,0 +1,633 @@ +# Assertions Reference + +This page lists the assertion macros provided by GoogleTest for verifying code +behavior. To use them, include the header `gtest/gtest.h`. + +The majority of the macros listed below come as a pair with an `EXPECT_` variant +and an `ASSERT_` variant. 
Upon failure, `EXPECT_` macros generate nonfatal +failures and allow the current function to continue running, while `ASSERT_` +macros generate fatal failures and abort the current function. + +All assertion macros support streaming a custom failure message into them with +the `<<` operator, for example: + +```cpp +EXPECT_TRUE(my_condition) << "My condition is not true"; +``` + +Anything that can be streamed to an `ostream` can be streamed to an assertion +macro—in particular, C strings and string objects. If a wide string (`wchar_t*`, +`TCHAR*` in `UNICODE` mode on Windows, or `std::wstring`) is streamed to an +assertion, it will be translated to UTF-8 when printed. + +## Explicit Success and Failure {#success-failure} + +The assertions in this section generate a success or failure directly instead of +testing a value or expression. These are useful when control flow, rather than a +Boolean expression, determines the test's success or failure, as shown by the +following example: + +```c++ +switch(expression) { + case 1: + ... some checks ... + case 2: + ... some other checks ... + default: + FAIL() << "We shouldn't get here."; +} +``` + +### SUCCEED {#SUCCEED} + +`SUCCEED()` + +Generates a success. This *does not* make the overall test succeed. A test is +considered successful only if none of its assertions fail during its execution. + +The `SUCCEED` assertion is purely documentary and currently doesn't generate any +user-visible output. However, we may add `SUCCEED` messages to GoogleTest output +in the future. + +### FAIL {#FAIL} + +`FAIL()` + +Generates a fatal failure, which returns from the current function. + +Can only be used in functions that return `void`. See +[Assertion Placement](../advanced.md#assertion-placement) for more information. + +### ADD_FAILURE {#ADD_FAILURE} + +`ADD_FAILURE()` + +Generates a nonfatal failure, which allows the current function to continue +running. 
+
+### ADD_FAILURE_AT {#ADD_FAILURE_AT}
+
+`ADD_FAILURE_AT(`*`file_path`*`,`*`line_number`*`)`
+
+Generates a nonfatal failure at the file and line number specified.
+
+## Generalized Assertion {#generalized}
+
+The following assertion allows [matchers](matchers.md) to be used to verify
+values.
+
+### EXPECT_THAT {#EXPECT_THAT}
+
+`EXPECT_THAT(`*`value`*`,`*`matcher`*`)` \
+`ASSERT_THAT(`*`value`*`,`*`matcher`*`)`
+
+Verifies that *`value`* matches the [matcher](matchers.md) *`matcher`*.
+
+For example, the following code verifies that the string `value1` starts with
+`"Hello"`, `value2` matches a regular expression, and `value3` is between 5 and
+10:
+
+```cpp
+#include <gmock/gmock.h>
+
+using ::testing::AllOf;
+using ::testing::Gt;
+using ::testing::Lt;
+using ::testing::MatchesRegex;
+using ::testing::StartsWith;
+
+...
+EXPECT_THAT(value1, StartsWith("Hello"));
+EXPECT_THAT(value2, MatchesRegex("Line \\d+"));
+ASSERT_THAT(value3, AllOf(Gt(5), Lt(10)));
+```
+
+Matchers enable assertions of this form to read like English and generate
+informative failure messages. For example, if the above assertion on `value1`
+fails, the resulting message will be similar to the following:
+
+```
+Value of: value1
+  Actual: "Hi, world!"
+Expected: starts with "Hello"
+```
+
+GoogleTest provides a built-in library of matchers—see the
+[Matchers Reference](matchers.md). It is also possible to write your own
+matchers—see [Writing New Matchers Quickly](../gmock_cook_book.md#NewMatchers).
+The use of matchers makes `EXPECT_THAT` a powerful, extensible assertion.
+
+*The idea for this assertion was borrowed from Joe Walnes' Hamcrest project,
+which adds `assertThat()` to JUnit.*
+
+## Boolean Conditions {#boolean}
+
+The following assertions test Boolean conditions.
+
+### EXPECT_TRUE {#EXPECT_TRUE}
+
+`EXPECT_TRUE(`*`condition`*`)` \
+`ASSERT_TRUE(`*`condition`*`)`
+
+Verifies that *`condition`* is true.
+ +### EXPECT_FALSE {#EXPECT_FALSE} + +`EXPECT_FALSE(`*`condition`*`)` \ +`ASSERT_FALSE(`*`condition`*`)` + +Verifies that *`condition`* is false. + +## Binary Comparison {#binary-comparison} + +The following assertions compare two values. The value arguments must be +comparable by the assertion's comparison operator, otherwise a compiler error +will result. + +If an argument supports the `<<` operator, it will be called to print the +argument when the assertion fails. Otherwise, GoogleTest will attempt to print +them in the best way it can—see +[Teaching GoogleTest How to Print Your Values](../advanced.md#teaching-googletest-how-to-print-your-values). + +Arguments are always evaluated exactly once, so it's OK for the arguments to +have side effects. However, the argument evaluation order is undefined and +programs should not depend on any particular argument evaluation order. + +These assertions work with both narrow and wide string objects (`string` and +`wstring`). + +See also the [Floating-Point Comparison](#floating-point) assertions to compare +floating-point numbers and avoid problems caused by rounding. + +### EXPECT_EQ {#EXPECT_EQ} + +`EXPECT_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`==`*`val2`*. + +Does pointer equality on pointers. If used on two C strings, it tests if they +are in the same memory location, not if they have the same value. Use +[`EXPECT_STREQ`](#EXPECT_STREQ) to compare C strings (e.g. `const char*`) by +value. + +When comparing a pointer to `NULL`, use `EXPECT_EQ(`*`ptr`*`, nullptr)` instead +of `EXPECT_EQ(`*`ptr`*`, NULL)`. + +### EXPECT_NE {#EXPECT_NE} + +`EXPECT_NE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_NE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`!=`*`val2`*. + +Does pointer equality on pointers. If used on two C strings, it tests if they +are in different memory locations, not if they have different values. Use +[`EXPECT_STRNE`](#EXPECT_STRNE) to compare C strings (e.g. 
`const char*`) by +value. + +When comparing a pointer to `NULL`, use `EXPECT_NE(`*`ptr`*`, nullptr)` instead +of `EXPECT_NE(`*`ptr`*`, NULL)`. + +### EXPECT_LT {#EXPECT_LT} + +`EXPECT_LT(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_LT(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`<`*`val2`*. + +### EXPECT_LE {#EXPECT_LE} + +`EXPECT_LE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_LE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`<=`*`val2`*. + +### EXPECT_GT {#EXPECT_GT} + +`EXPECT_GT(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_GT(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`>`*`val2`*. + +### EXPECT_GE {#EXPECT_GE} + +`EXPECT_GE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_GE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`>=`*`val2`*. + +## String Comparison {#c-strings} + +The following assertions compare two **C strings**. To compare two `string` +objects, use [`EXPECT_EQ`](#EXPECT_EQ) or [`EXPECT_NE`](#EXPECT_NE) instead. + +These assertions also accept wide C strings (`wchar_t*`). If a comparison of two +wide strings fails, their values will be printed as UTF-8 narrow strings. + +To compare a C string with `NULL`, use `EXPECT_EQ(`*`c_string`*`, nullptr)` or +`EXPECT_NE(`*`c_string`*`, nullptr)`. + +### EXPECT_STREQ {#EXPECT_STREQ} + +`EXPECT_STREQ(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STREQ(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have the same contents. + +### EXPECT_STRNE {#EXPECT_STRNE} + +`EXPECT_STRNE(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STRNE(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have different contents. + +### EXPECT_STRCASEEQ {#EXPECT_STRCASEEQ} + +`EXPECT_STRCASEEQ(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STRCASEEQ(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have the same contents, +ignoring case. 
+ +### EXPECT_STRCASENE {#EXPECT_STRCASENE} + +`EXPECT_STRCASENE(`*`str1`*`,`*`str2`*`)` \ +`ASSERT_STRCASENE(`*`str1`*`,`*`str2`*`)` + +Verifies that the two C strings *`str1`* and *`str2`* have different contents, +ignoring case. + +## Floating-Point Comparison {#floating-point} + +The following assertions compare two floating-point values. + +Due to rounding errors, it is very unlikely that two floating-point values will +match exactly, so `EXPECT_EQ` is not suitable. In general, for floating-point +comparison to make sense, the user needs to carefully choose the error bound. + +GoogleTest also provides assertions that use a default error bound based on +Units in the Last Place (ULPs). To learn more about ULPs, see the article +[Comparing Floating Point Numbers](https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/). + +### EXPECT_FLOAT_EQ {#EXPECT_FLOAT_EQ} + +`EXPECT_FLOAT_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_FLOAT_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that the two `float` values *`val1`* and *`val2`* are approximately +equal, to within 4 ULPs from each other. + +### EXPECT_DOUBLE_EQ {#EXPECT_DOUBLE_EQ} + +`EXPECT_DOUBLE_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_DOUBLE_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that the two `double` values *`val1`* and *`val2`* are approximately +equal, to within 4 ULPs from each other. + +### EXPECT_NEAR {#EXPECT_NEAR} + +`EXPECT_NEAR(`*`val1`*`,`*`val2`*`,`*`abs_error`*`)` \ +`ASSERT_NEAR(`*`val1`*`,`*`val2`*`,`*`abs_error`*`)` + +Verifies that the difference between *`val1`* and *`val2`* does not exceed the +absolute error bound *`abs_error`*. + +## Exception Assertions {#exceptions} + +The following assertions verify that a piece of code throws, or does not throw, +an exception. Usage requires exceptions to be enabled in the build environment. 
+ +Note that the piece of code under test can be a compound statement, for example: + +```cpp +EXPECT_NO_THROW({ + int n = 5; + DoSomething(&n); +}); +``` + +### EXPECT_THROW {#EXPECT_THROW} + +`EXPECT_THROW(`*`statement`*`,`*`exception_type`*`)` \ +`ASSERT_THROW(`*`statement`*`,`*`exception_type`*`)` + +Verifies that *`statement`* throws an exception of type *`exception_type`*. + +### EXPECT_ANY_THROW {#EXPECT_ANY_THROW} + +`EXPECT_ANY_THROW(`*`statement`*`)` \ +`ASSERT_ANY_THROW(`*`statement`*`)` + +Verifies that *`statement`* throws an exception of any type. + +### EXPECT_NO_THROW {#EXPECT_NO_THROW} + +`EXPECT_NO_THROW(`*`statement`*`)` \ +`ASSERT_NO_THROW(`*`statement`*`)` + +Verifies that *`statement`* does not throw any exception. + +## Predicate Assertions {#predicates} + +The following assertions enable more complex predicates to be verified while +printing a more clear failure message than if `EXPECT_TRUE` were used alone. + +### EXPECT_PRED* {#EXPECT_PRED} + +`EXPECT_PRED1(`*`pred`*`,`*`val1`*`)` \ +`EXPECT_PRED2(`*`pred`*`,`*`val1`*`,`*`val2`*`)` \ +`EXPECT_PRED3(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \ +`EXPECT_PRED4(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \ +`EXPECT_PRED5(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)` + +`ASSERT_PRED1(`*`pred`*`,`*`val1`*`)` \ +`ASSERT_PRED2(`*`pred`*`,`*`val1`*`,`*`val2`*`)` \ +`ASSERT_PRED3(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \ +`ASSERT_PRED4(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \ +`ASSERT_PRED5(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)` + +Verifies that the predicate *`pred`* returns `true` when passed the given values +as arguments. + +The parameter *`pred`* is a function or functor that accepts as many arguments +as the corresponding macro accepts values. If *`pred`* returns `true` for the +given arguments, the assertion succeeds, otherwise the assertion fails. 
+
+When the assertion fails, it prints the value of each argument. Arguments are
+always evaluated exactly once.
+
+As an example, see the following code:
+
+```cpp
+// Returns true if m and n have no common divisors except 1.
+bool MutuallyPrime(int m, int n) { ... }
+...
+const int a = 3;
+const int b = 4;
+const int c = 10;
+...
+EXPECT_PRED2(MutuallyPrime, a, b);  // Succeeds
+EXPECT_PRED2(MutuallyPrime, b, c);  // Fails
+```
+
+In the above example, the first assertion succeeds, and the second fails with
+the following message:
+
+```
+MutuallyPrime(b, c) is false, where
+b is 4
+c is 10
+```
+
+Note that if the given predicate is an overloaded function or a function
+template, the assertion macro might not be able to determine which version to
+use, and it might be necessary to explicitly specify the type of the function.
+For example, for a Boolean function `IsPositive()` overloaded to take either a
+single `int` or `double` argument, it would be necessary to write one of the
+following:
+
+```cpp
+EXPECT_PRED1(static_cast<bool (*)(int)>(IsPositive), 5);
+EXPECT_PRED1(static_cast<bool (*)(double)>(IsPositive), 3.14);
+```
+
+Writing simply `EXPECT_PRED1(IsPositive, 5);` would result in a compiler error.
+Similarly, to use a template function, specify the template arguments:
+
+```cpp
+template <typename T>
+bool IsNegative(T x) {
+  return x < 0;
+}
+...
+EXPECT_PRED1(IsNegative<int>, -5);  // Must specify type for IsNegative
+```
+
+If a template has multiple parameters, wrap the predicate in parentheses so the
+macro arguments are parsed correctly:
+
+```cpp
+ASSERT_PRED2((MyPredicate<int, int>), 5, 0);
+```
+
+### EXPECT_PRED_FORMAT* {#EXPECT_PRED_FORMAT}
+
+`EXPECT_PRED_FORMAT1(`*`pred_formatter`*`,`*`val1`*`)` \
+`EXPECT_PRED_FORMAT2(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`)` \
+`EXPECT_PRED_FORMAT3(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \
+`EXPECT_PRED_FORMAT4(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)`
+\
+`EXPECT_PRED_FORMAT5(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)`
+
+`ASSERT_PRED_FORMAT1(`*`pred_formatter`*`,`*`val1`*`)` \
+`ASSERT_PRED_FORMAT2(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`)` \
+`ASSERT_PRED_FORMAT3(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \
+`ASSERT_PRED_FORMAT4(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)`
+\
+`ASSERT_PRED_FORMAT5(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)`
+
+Verifies that the predicate *`pred_formatter`* succeeds when passed the given
+values as arguments.
+
+The parameter *`pred_formatter`* is a *predicate-formatter*, which is a function
+or functor with the signature:
+
+```cpp
+testing::AssertionResult PredicateFormatter(const char* expr1,
+                                            const char* expr2,
+                                            ...
+                                            const char* exprn,
+                                            T1 val1,
+                                            T2 val2,
+                                            ...
+                                            Tn valn);
+```
+
+where *`val1`*, *`val2`*, ..., *`valn`* are the values of the predicate
+arguments, and *`expr1`*, *`expr2`*, ..., *`exprn`* are the corresponding
+expressions as they appear in the source code. The types `T1`, `T2`, ..., `Tn`
+can be either value types or reference types; if an argument has type `T`, it
+can be declared as either `T` or `const T&`, whichever is appropriate.
For more +about the return type `testing::AssertionResult`, see +[Using a Function That Returns an AssertionResult](../advanced.md#using-a-function-that-returns-an-assertionresult). + +As an example, see the following code: + +```cpp +// Returns the smallest prime common divisor of m and n, +// or 1 when m and n are mutually prime. +int SmallestPrimeCommonDivisor(int m, int n) { ... } + +// Returns true if m and n have no common divisors except 1. +bool MutuallyPrime(int m, int n) { ... } + +// A predicate-formatter for asserting that two integers are mutually prime. +testing::AssertionResult AssertMutuallyPrime(const char* m_expr, + const char* n_expr, + int m, + int n) { + if (MutuallyPrime(m, n)) return testing::AssertionSuccess(); + + return testing::AssertionFailure() << m_expr << " and " << n_expr + << " (" << m << " and " << n << ") are not mutually prime, " + << "as they have a common divisor " << SmallestPrimeCommonDivisor(m, n); +} + +... +const int a = 3; +const int b = 4; +const int c = 10; +... +EXPECT_PRED_FORMAT2(AssertMutuallyPrime, a, b); // Succeeds +EXPECT_PRED_FORMAT2(AssertMutuallyPrime, b, c); // Fails +``` + +In the above example, the final assertion fails and the predicate-formatter +produces the following failure message: + +``` +b and c (4 and 10) are not mutually prime, as they have a common divisor 2 +``` + +## Windows HRESULT Assertions {#HRESULT} + +The following assertions test for `HRESULT` success or failure. For example: + +```cpp +CComPtr shell; +ASSERT_HRESULT_SUCCEEDED(shell.CoCreateInstance(L"Shell.Application")); +CComVariant empty; +ASSERT_HRESULT_SUCCEEDED(shell->ShellExecute(CComBSTR(url), empty, empty, empty, empty)); +``` + +The generated output contains the human-readable error message associated with +the returned `HRESULT` code. 
+ +### EXPECT_HRESULT_SUCCEEDED {#EXPECT_HRESULT_SUCCEEDED} + +`EXPECT_HRESULT_SUCCEEDED(`*`expression`*`)` \ +`ASSERT_HRESULT_SUCCEEDED(`*`expression`*`)` + +Verifies that *`expression`* is a success `HRESULT`. + +### EXPECT_HRESULT_FAILED {#EXPECT_HRESULT_FAILED} + +`EXPECT_HRESULT_FAILED(`*`expression`*`)` \ +`ASSERT_HRESULT_FAILED(`*`expression`*`)` + +Verifies that *`expression`* is a failure `HRESULT`. + +## Death Assertions {#death} + +The following assertions verify that a piece of code causes the process to +terminate. For context, see [Death Tests](../advanced.md#death-tests). + +These assertions spawn a new process and execute the code under test in that +process. How that happens depends on the platform and the variable +`::testing::GTEST_FLAG(death_test_style)`, which is initialized from the +command-line flag `--gtest_death_test_style`. + +* On POSIX systems, `fork()` (or `clone()` on Linux) is used to spawn the + child, after which: + * If the variable's value is `"fast"`, the death test statement is + immediately executed. + * If the variable's value is `"threadsafe"`, the child process re-executes + the unit test binary just as it was originally invoked, but with some + extra flags to cause just the single death test under consideration to + be run. +* On Windows, the child is spawned using the `CreateProcess()` API, and + re-executes the binary to cause just the single death test under + consideration to be run - much like the `"threadsafe"` mode on POSIX. + +Other values for the variable are illegal and will cause the death test to fail. +Currently, the flag's default value is +**`"fast"`**. + +If the death test statement runs to completion without dying, the child process +will nonetheless terminate, and the assertion fails. 
+ +Note that the piece of code under test can be a compound statement, for example: + +```cpp +EXPECT_DEATH({ + int n = 5; + DoSomething(&n); +}, "Error on line .* of DoSomething()"); +``` + +### EXPECT_DEATH {#EXPECT_DEATH} + +`EXPECT_DEATH(`*`statement`*`,`*`matcher`*`)` \ +`ASSERT_DEATH(`*`statement`*`,`*`matcher`*`)` + +Verifies that *`statement`* causes the process to terminate with a nonzero exit +status and produces `stderr` output that matches *`matcher`*. + +The parameter *`matcher`* is either a [matcher](matchers.md) for a `const +std::string&`, or a regular expression (see +[Regular Expression Syntax](../advanced.md#regular-expression-syntax))—a bare +string *`s`* (with no matcher) is treated as +[`ContainsRegex(s)`](matchers.md#string-matchers), **not** +[`Eq(s)`](matchers.md#generic-comparison). + +For example, the following code verifies that calling `DoSomething(42)` causes +the process to die with an error message that contains the text `My error`: + +```cpp +EXPECT_DEATH(DoSomething(42), "My error"); +``` + +### EXPECT_DEATH_IF_SUPPORTED {#EXPECT_DEATH_IF_SUPPORTED} + +`EXPECT_DEATH_IF_SUPPORTED(`*`statement`*`,`*`matcher`*`)` \ +`ASSERT_DEATH_IF_SUPPORTED(`*`statement`*`,`*`matcher`*`)` + +If death tests are supported, behaves the same as +[`EXPECT_DEATH`](#EXPECT_DEATH). Otherwise, verifies nothing. + +### EXPECT_DEBUG_DEATH {#EXPECT_DEBUG_DEATH} + +`EXPECT_DEBUG_DEATH(`*`statement`*`,`*`matcher`*`)` \ +`ASSERT_DEBUG_DEATH(`*`statement`*`,`*`matcher`*`)` + +In debug mode, behaves the same as [`EXPECT_DEATH`](#EXPECT_DEATH). When not in +debug mode (i.e. `NDEBUG` is defined), just executes *`statement`*. + +### EXPECT_EXIT {#EXPECT_EXIT} + +`EXPECT_EXIT(`*`statement`*`,`*`predicate`*`,`*`matcher`*`)` \ +`ASSERT_EXIT(`*`statement`*`,`*`predicate`*`,`*`matcher`*`)` + +Verifies that *`statement`* causes the process to terminate with an exit status +that satisfies *`predicate`*, and produces `stderr` output that matches +*`matcher`*. 
+ +The parameter *`predicate`* is a function or functor that accepts an `int` exit +status and returns a `bool`. GoogleTest provides two predicates to handle common +cases: + +```cpp +// Returns true if the program exited normally with the given exit status code. +::testing::ExitedWithCode(exit_code); + +// Returns true if the program was killed by the given signal. +// Not available on Windows. +::testing::KilledBySignal(signal_number); +``` + +The parameter *`matcher`* is either a [matcher](matchers.md) for a `const +std::string&`, or a regular expression (see +[Regular Expression Syntax](../advanced.md#regular-expression-syntax))—a bare +string *`s`* (with no matcher) is treated as +[`ContainsRegex(s)`](matchers.md#string-matchers), **not** +[`Eq(s)`](matchers.md#generic-comparison). + +For example, the following code verifies that calling `NormalExit()` causes the +process to print a message containing the text `Success` to `stderr` and exit +with exit status code 0: + +```cpp +EXPECT_EXIT(NormalExit(), testing::ExitedWithCode(0), "Success"); +``` diff --git a/third_party/googletest/docs/reference/matchers.md b/third_party/googletest/docs/reference/matchers.md new file mode 100644 index 0000000..243e3f9 --- /dev/null +++ b/third_party/googletest/docs/reference/matchers.md @@ -0,0 +1,302 @@ +# Matchers Reference + +A **matcher** matches a *single* argument. You can use it inside `ON_CALL()` or +`EXPECT_CALL()`, or use it to validate a value directly using two macros: + +| Macro | Description | +| :----------------------------------- | :------------------------------------ | +| `EXPECT_THAT(actual_value, matcher)` | Asserts that `actual_value` matches `matcher`. | +| `ASSERT_THAT(actual_value, matcher)` | The same as `EXPECT_THAT(actual_value, matcher)`, except that it generates a **fatal** failure. 
|
+
+{: .callout .warning}
+**WARNING:** Equality matching via `EXPECT_THAT(actual_value, expected_value)`
+is supported, however note that implicit conversions can cause surprising
+results. For example, `EXPECT_THAT(some_bool, "some string")` will compile and
+may pass unintentionally.
+
+**BEST PRACTICE:** Prefer to make the comparison explicit via
+`EXPECT_THAT(actual_value, Eq(expected_value))` or `EXPECT_EQ(actual_value,
+expected_value)`.
+
+Built-in matchers (where `argument` is the function argument, e.g.
+`actual_value` in the example above, or when used in the context of
+`EXPECT_CALL(mock_object, method(matchers))`, the arguments of `method`) are
+divided into several categories. All matchers are defined in the `::testing`
+namespace unless otherwise noted.
+
+## Wildcard
+
+Matcher                     | Description
+:-------------------------- | :-----------------------------------------------
+`_`                         | `argument` can be any value of the correct type.
+`A<type>()` or `An<type>()` | `argument` can be any value of type `type`.
+
+## Generic Comparison
+
+| Matcher | Description |
+| :--------------------- | :-------------------------------------------------- |
+| `Eq(value)` or `value` | `argument == value` |
+| `Ge(value)` | `argument >= value` |
+| `Gt(value)` | `argument > value` |
+| `Le(value)` | `argument <= value` |
+| `Lt(value)` | `argument < value` |
+| `Ne(value)` | `argument != value` |
+| `IsFalse()` | `argument` evaluates to `false` in a Boolean context. |
+| `IsTrue()` | `argument` evaluates to `true` in a Boolean context. |
+| `IsNull()` | `argument` is a `NULL` pointer (raw or smart). |
+| `NotNull()` | `argument` is a non-null pointer (raw or smart). |
+| `Optional(m)` | `argument` is `optional<>` that contains a value matching `m`. (For testing whether an `optional<>` is set, check for equality with `nullopt`.
You may need to use `Eq(nullopt)` if the inner type doesn't have `==`.)| +| `VariantWith<T>(m)` | `argument` is `variant<>` that holds the alternative of type T with a value matching `m`. | +| `Ref(variable)` | `argument` is a reference to `variable`. | +| `TypedEq<type>(value)` | `argument` has type `type` and is equal to `value`. You may need to use this instead of `Eq(value)` when the mock function is overloaded. | + +Except `Ref()`, these matchers make a *copy* of `value` in case it's modified or +destructed later. If the compiler complains that `value` doesn't have a public +copy constructor, try wrapping it in `std::ref()`, e.g. +`Eq(std::ref(non_copyable_value))`. If you do that, make sure +`non_copyable_value` is not changed afterwards, or the meaning of your matcher +will be changed. + +`IsTrue` and `IsFalse` are useful when you need to use a matcher, or for types +that can be explicitly converted to Boolean, but are not implicitly converted to +Boolean. In other cases, you can use the basic +[`EXPECT_TRUE` and `EXPECT_FALSE`](assertions.md#boolean) assertions. + +## Floating-Point Matchers {#FpMatchers} + +| Matcher | Description | +| :------------------------------- | :--------------------------------- | +| `DoubleEq(a_double)` | `argument` is a `double` value approximately equal to `a_double`, treating two NaNs as unequal. | +| `FloatEq(a_float)` | `argument` is a `float` value approximately equal to `a_float`, treating two NaNs as unequal. | +| `NanSensitiveDoubleEq(a_double)` | `argument` is a `double` value approximately equal to `a_double`, treating two NaNs as equal. | +| `NanSensitiveFloatEq(a_float)` | `argument` is a `float` value approximately equal to `a_float`, treating two NaNs as equal. | +| `IsNan()` | `argument` is any floating-point type with a NaN value. | + +The above matchers use ULP-based comparison (the same as used in googletest). +They automatically pick a reasonable error bound based on the absolute value of +the expected value. 
`DoubleEq()` and `FloatEq()` conform to the IEEE standard, +which requires comparing two NaNs for equality to return false. The +`NanSensitive*` version instead treats two NaNs as equal, which is often what a +user wants. + +| Matcher | Description | +| :------------------------------------------------ | :----------------------- | +| `DoubleNear(a_double, max_abs_error)` | `argument` is a `double` value close to `a_double` (absolute error <= `max_abs_error`), treating two NaNs as unequal. | +| `FloatNear(a_float, max_abs_error)` | `argument` is a `float` value close to `a_float` (absolute error <= `max_abs_error`), treating two NaNs as unequal. | +| `NanSensitiveDoubleNear(a_double, max_abs_error)` | `argument` is a `double` value close to `a_double` (absolute error <= `max_abs_error`), treating two NaNs as equal. | +| `NanSensitiveFloatNear(a_float, max_abs_error)` | `argument` is a `float` value close to `a_float` (absolute error <= `max_abs_error`), treating two NaNs as equal. | + +## String Matchers + +The `argument` can be either a C string or a C++ string object: + +| Matcher | Description | +| :---------------------- | :------------------------------------------------- | +| `ContainsRegex(string)` | `argument` matches the given regular expression. | +| `EndsWith(suffix)` | `argument` ends with string `suffix`. | +| `HasSubstr(string)` | `argument` contains `string` as a sub-string. | +| `IsEmpty()` | `argument` is an empty string. | +| `MatchesRegex(string)` | `argument` matches the given regular expression with the match starting at the first character and ending at the last character. | +| `StartsWith(prefix)` | `argument` starts with string `prefix`. | +| `StrCaseEq(string)` | `argument` is equal to `string`, ignoring case. | +| `StrCaseNe(string)` | `argument` is not equal to `string`, ignoring case. | +| `StrEq(string)` | `argument` is equal to `string`. | +| `StrNe(string)` | `argument` is not equal to `string`. 
| +| `WhenBase64Unescaped(m)` | `argument` is a base-64 escaped string whose unescaped string matches `m`. The web-safe format from [RFC 4648](https://www.rfc-editor.org/rfc/rfc4648#section-5) is supported. | + +`ContainsRegex()` and `MatchesRegex()` take ownership of the `RE` object. They +use the regular expression syntax defined +[here](../advanced.md#regular-expression-syntax). All of these matchers, except +`ContainsRegex()` and `MatchesRegex()` work for wide strings as well. + +## Container Matchers + +Most STL-style containers support `==`, so you can use `Eq(expected_container)` +or simply `expected_container` to match a container exactly. If you want to +write the elements in-line, match them more flexibly, or get more informative +messages, you can use: + +| Matcher | Description | +| :---------------------------------------- | :------------------------------- | +| `BeginEndDistanceIs(m)` | `argument` is a container whose `begin()` and `end()` iterators are separated by a number of increments matching `m`. E.g. `BeginEndDistanceIs(2)` or `BeginEndDistanceIs(Lt(2))`. For containers that define a `size()` method, `SizeIs(m)` may be more efficient. | +| `ContainerEq(container)` | The same as `Eq(container)` except that the failure message also includes which elements are in one container but not the other. | +| `Contains(e)` | `argument` contains an element that matches `e`, which can be either a value or a matcher. | +| `Contains(e).Times(n)` | `argument` contains elements that match `e`, which can be either a value or a matcher, and the number of matches is `n`, which can be either a value or a matcher. Unlike the plain `Contains` and `Each` this allows to check for arbitrary occurrences including testing for absence with `Contains(e).Times(0)`. | +| `Each(e)` | `argument` is a container where *every* element matches `e`, which can be either a value or a matcher. 
| +| `ElementsAre(e0, e1, ..., en)` | `argument` has `n + 1` elements, where the *i*-th element matches `ei`, which can be a value or a matcher. | +| `ElementsAreArray({e0, e1, ..., en})`, `ElementsAreArray(a_container)`, `ElementsAreArray(begin, end)`, `ElementsAreArray(array)`, or `ElementsAreArray(array, count)` | The same as `ElementsAre()` except that the expected element values/matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `IsEmpty()` | `argument` is an empty container (`container.empty()`). | +| `IsSubsetOf({e0, e1, ..., en})`, `IsSubsetOf(a_container)`, `IsSubsetOf(begin, end)`, `IsSubsetOf(array)`, or `IsSubsetOf(array, count)` | `argument` matches `UnorderedElementsAre(x0, x1, ..., xk)` for some subset `{x0, x1, ..., xk}` of the expected matchers. | +| `IsSupersetOf({e0, e1, ..., en})`, `IsSupersetOf(a_container)`, `IsSupersetOf(begin, end)`, `IsSupersetOf(array)`, or `IsSupersetOf(array, count)` | Some subset of `argument` matches `UnorderedElementsAre(`expected matchers`)`. | +| `Pointwise(m, container)`, `Pointwise(m, {e0, e1, ..., en})` | `argument` contains the same number of elements as in `container`, and for all i, (the i-th element in `argument`, the i-th element in `container`) match `m`, which is a matcher on 2-tuples. E.g. `Pointwise(Le(), upper_bounds)` verifies that each element in `argument` doesn't exceed the corresponding element in `upper_bounds`. See more detail below. | +| `SizeIs(m)` | `argument` is a container whose size matches `m`. E.g. `SizeIs(2)` or `SizeIs(Lt(2))`. | +| `UnorderedElementsAre(e0, e1, ..., en)` | `argument` has `n + 1` elements, and under *some* permutation of the elements, each element matches an `ei` (for a different `i`), which can be a value or a matcher. 
| +| `UnorderedElementsAreArray({e0, e1, ..., en})`, `UnorderedElementsAreArray(a_container)`, `UnorderedElementsAreArray(begin, end)`, `UnorderedElementsAreArray(array)`, or `UnorderedElementsAreArray(array, count)` | The same as `UnorderedElementsAre()` except that the expected element values/matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `UnorderedPointwise(m, container)`, `UnorderedPointwise(m, {e0, e1, ..., en})` | Like `Pointwise(m, container)`, but ignores the order of elements. | +| `WhenSorted(m)` | When `argument` is sorted using the `<` operator, it matches container matcher `m`. E.g. `WhenSorted(ElementsAre(1, 2, 3))` verifies that `argument` contains elements 1, 2, and 3, ignoring order. | +| `WhenSortedBy(comparator, m)` | The same as `WhenSorted(m)`, except that the given comparator instead of `<` is used to sort `argument`. E.g. `WhenSortedBy(std::greater(), ElementsAre(3, 2, 1))`. | + +**Notes:** + +* These matchers can also match: + 1. a native array passed by reference (e.g. in `Foo(const int (&a)[5])`), + and + 2. an array passed as a pointer and a count (e.g. in `Bar(const T* buffer, + int len)` -- see [Multi-argument Matchers](#MultiArgMatchers)). +* The array being matched may be multi-dimensional (i.e. its elements can be + arrays). +* `m` in `Pointwise(m, ...)` and `UnorderedPointwise(m, ...)` should be a + matcher for `::std::tuple` where `T` and `U` are the element type of + the actual container and the expected container, respectively. For example, + to compare two `Foo` containers where `Foo` doesn't support `operator==`, + one might write: + + ```cpp + MATCHER(FooEq, "") { + return std::get<0>(arg).Equals(std::get<1>(arg)); + } + ... 
+ EXPECT_THAT(actual_foos, Pointwise(FooEq(), expected_foos)); + ``` + +## Member Matchers + +| Matcher | Description | +| :------------------------------ | :----------------------------------------- | +| `Field(&class::field, m)` | `argument.field` (or `argument->field` when `argument` is a plain pointer) matches matcher `m`, where `argument` is an object of type _class_. | +| `Field(field_name, &class::field, m)` | The same as the two-parameter version, but provides a better error message. | +| `Key(e)` | `argument.first` matches `e`, which can be either a value or a matcher. E.g. `Contains(Key(Le(5)))` can verify that a `map` contains a key `<= 5`. | +| `Pair(m1, m2)` | `argument` is an `std::pair` whose `first` field matches `m1` and `second` field matches `m2`. | +| `FieldsAre(m...)` | `argument` is a compatible object where each field matches piecewise with the matchers `m...`. A compatible object is any that supports the `std::tuple_size`+`get(obj)` protocol. In C++17 and up this also supports types compatible with structured bindings, like aggregates. | +| `Property(&class::property, m)` | `argument.property()` (or `argument->property()` when `argument` is a plain pointer) matches matcher `m`, where `argument` is an object of type _class_. The method `property()` must take no argument and be declared as `const`. | +| `Property(property_name, &class::property, m)` | The same as the two-parameter version, but provides a better error message. + +**Notes:** + +* You can use `FieldsAre()` to match any type that supports structured + bindings, such as `std::tuple`, `std::pair`, `std::array`, and aggregate + types. 
For example: + + ```cpp + std::tuple<int, std::string> my_tuple{7, "hello world"}; + EXPECT_THAT(my_tuple, FieldsAre(Ge(0), HasSubstr("hello"))); + + struct MyStruct { + int value = 42; + std::string greeting = "aloha"; + }; + MyStruct s; + EXPECT_THAT(s, FieldsAre(42, "aloha")); + ``` + +* Don't use `Property()` against member functions that you do not own, because + taking addresses of functions is fragile and generally not part of the + contract of the function. + +## Matching the Result of a Function, Functor, or Callback + +| Matcher | Description | +| :--------------- | :------------------------------------------------ | +| `ResultOf(f, m)` | `f(argument)` matches matcher `m`, where `f` is a function or functor. | +| `ResultOf(result_description, f, m)` | The same as the two-parameter version, but provides a better error message. + +## Pointer Matchers + +| Matcher | Description | +| :------------------------ | :---------------------------------------------- | +| `Address(m)` | the result of `std::addressof(argument)` matches `m`. | +| `Pointee(m)` | `argument` (either a smart pointer or a raw pointer) points to a value that matches matcher `m`. | +| `Pointer(m)` | `argument` (either a smart pointer or a raw pointer) contains a pointer that matches `m`. `m` will match against the raw pointer regardless of the type of `argument`. | +| `WhenDynamicCastTo<T>(m)` | when `argument` is passed through `dynamic_cast<T>()`, it matches matcher `m`. | + +## Multi-argument Matchers {#MultiArgMatchers} + +Technically, all matchers match a *single* value. A "multi-argument" matcher is +just one that matches a *tuple*. 
The following matchers can be used to match a +tuple `(x, y)`: + +Matcher | Description +:------ | :---------- +`Eq()` | `x == y` +`Ge()` | `x >= y` +`Gt()` | `x > y` +`Le()` | `x <= y` +`Lt()` | `x < y` +`Ne()` | `x != y` + +You can use the following selectors to pick a subset of the arguments (or +reorder them) to participate in the matching: + +| Matcher | Description | +| :------------------------- | :---------------------------------------------- | +| `AllArgs(m)` | Equivalent to `m`. Useful as syntactic sugar in `.With(AllArgs(m))`. | +| `Args<N1, N2, ..., Nk>(m)` | The tuple of the `k` selected (using 0-based indices) arguments matches `m`, e.g. `Args<1, 2>(Eq())`. | + +## Composite Matchers + +You can make a matcher from one or more other matchers: + +| Matcher | Description | +| :------------------------------- | :-------------------------------------- | +| `AllOf(m1, m2, ..., mn)` | `argument` matches all of the matchers `m1` to `mn`. | +| `AllOfArray({m0, m1, ..., mn})`, `AllOfArray(a_container)`, `AllOfArray(begin, end)`, `AllOfArray(array)`, or `AllOfArray(array, count)` | The same as `AllOf()` except that the matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `AnyOf(m1, m2, ..., mn)` | `argument` matches at least one of the matchers `m1` to `mn`. | +| `AnyOfArray({m0, m1, ..., mn})`, `AnyOfArray(a_container)`, `AnyOfArray(begin, end)`, `AnyOfArray(array)`, or `AnyOfArray(array, count)` | The same as `AnyOf()` except that the matchers come from an initializer list, STL-style container, iterator range, or C-style array. | +| `Not(m)` | `argument` doesn't match matcher `m`. | +| `Conditional(cond, m1, m2)` | Matches matcher `m1` if `cond` evaluates to true, else matches `m2`.| + +## Adapters for Matchers + +| Matcher | Description | +| :---------------------- | :------------------------------------ | +| `MatcherCast<T>(m)` | casts matcher `m` to type `Matcher<T>`. 
| +| `SafeMatcherCast<T>(m)` | [safely casts](../gmock_cook_book.md#SafeMatcherCast) matcher `m` to type `Matcher<T>`. | +| `Truly(predicate)` | `predicate(argument)` returns something considered by C++ to be true, where `predicate` is a function or functor. | + +`AddressSatisfies(callback)` and `Truly(callback)` take ownership of `callback`, +which must be a permanent callback. + +## Using Matchers as Predicates {#MatchersAsPredicatesCheat} + +| Matcher | Description | +| :---------------------------- | :------------------------------------------ | +| `Matches(m)(value)` | evaluates to `true` if `value` matches `m`. You can use `Matches(m)` alone as a unary functor. | +| `ExplainMatchResult(m, value, result_listener)` | evaluates to `true` if `value` matches `m`, explaining the result to `result_listener`. | +| `Value(value, m)` | evaluates to `true` if `value` matches `m`. | + +## Defining Matchers + +| Macro | Description | +| :----------------------------------- | :------------------------------------ | +| `MATCHER(IsEven, "") { return (arg % 2) == 0; }` | Defines a matcher `IsEven()` to match an even number. | +| `MATCHER_P(IsDivisibleBy, n, "") { *result_listener << "where the remainder is " << (arg % n); return (arg % n) == 0; }` | Defines a matcher `IsDivisibleBy(n)` to match a number divisible by `n`. | +| `MATCHER_P2(IsBetween, a, b, absl::StrCat(negation ? "isn't" : "is", " between ", PrintToString(a), " and ", PrintToString(b))) { return a <= arg && arg <= b; }` | Defines a matcher `IsBetween(a, b)` to match a value in the range [`a`, `b`]. | + +**Notes:** + +1. The `MATCHER*` macros cannot be used inside a function or class. +2. The matcher body must be *purely functional* (i.e. it cannot have any side + effect, and the result must not depend on anything other than the value + being matched and the matcher parameters). +3. You can use `PrintToString(x)` to convert a value `x` of any type to a + string. +4. 
You can use `ExplainMatchResult()` in a custom matcher to wrap another + matcher, for example: + + ```cpp + MATCHER_P(NestedPropertyMatches, matcher, "") { + return ExplainMatchResult(matcher, arg.nested().property(), result_listener); + } + ``` + +5. You can use `DescribeMatcher<>` to describe another matcher. For example: + + ```cpp + MATCHER_P(XAndYThat, matcher, + "X that " + DescribeMatcher(matcher, negation) + + (negation ? " or" : " and") + " Y that " + + DescribeMatcher(matcher, negation)) { + return ExplainMatchResult(matcher, arg.x(), result_listener) && + ExplainMatchResult(matcher, arg.y(), result_listener); + } + ``` diff --git a/third_party/googletest/docs/reference/mocking.md b/third_party/googletest/docs/reference/mocking.md new file mode 100644 index 0000000..e414ffb --- /dev/null +++ b/third_party/googletest/docs/reference/mocking.md @@ -0,0 +1,589 @@ +# Mocking Reference + +This page lists the facilities provided by GoogleTest for creating and working +with mock objects. To use them, include the header +`gmock/gmock.h`. + +## Macros {#macros} + +GoogleTest defines the following macros for working with mocks. + +### MOCK_METHOD {#MOCK_METHOD} + +`MOCK_METHOD(`*`return_type`*`,`*`method_name`*`, (`*`args...`*`));` \ +`MOCK_METHOD(`*`return_type`*`,`*`method_name`*`, (`*`args...`*`), +(`*`specs...`*`));` + +Defines a mock method *`method_name`* with arguments `(`*`args...`*`)` and +return type *`return_type`* within a mock class. + +The parameters of `MOCK_METHOD` mirror the method declaration. The optional +fourth parameter *`specs...`* is a comma-separated list of qualifiers. The +following qualifiers are accepted: + +| Qualifier | Meaning | +| -------------------------- | -------------------------------------------- | +| `const` | Makes the mocked method a `const` method. Required if overriding a `const` method. | +| `override` | Marks the method with `override`. Recommended if overriding a `virtual` method. 
| +| `noexcept` | Marks the method with `noexcept`. Required if overriding a `noexcept` method. | +| `Calltype(`*`calltype`*`)` | Sets the call type for the method, for example `Calltype(STDMETHODCALLTYPE)`. Useful on Windows. | +| `ref(`*`qualifier`*`)` | Marks the method with the given reference qualifier, for example `ref(&)` or `ref(&&)`. Required if overriding a method that has a reference qualifier. | + +Note that commas in arguments prevent `MOCK_METHOD` from parsing the arguments +correctly if they are not appropriately surrounded by parentheses. See the +following example: + +```cpp +class MyMock { + public: + // The following 2 lines will not compile due to commas in the arguments: + MOCK_METHOD(std::pair<bool, int>, GetPair, ()); // Error! + MOCK_METHOD(bool, CheckMap, (std::map<int, double>, bool)); // Error! + + // One solution - wrap arguments that contain commas in parentheses: + MOCK_METHOD((std::pair<bool, int>), GetPair, ()); + MOCK_METHOD(bool, CheckMap, ((std::map<int, double>), bool)); + + // Another solution - use type aliases: + using BoolAndInt = std::pair<bool, int>; + MOCK_METHOD(BoolAndInt, GetPair, ()); + using MapIntDouble = std::map<int, double>; + MOCK_METHOD(bool, CheckMap, (MapIntDouble, bool)); +}; +``` + +`MOCK_METHOD` must be used in the `public:` section of a mock class definition, +regardless of whether the method being mocked is `public`, `protected`, or +`private` in the base class. + +### EXPECT_CALL {#EXPECT_CALL} + +`EXPECT_CALL(`*`mock_object`*`,`*`method_name`*`(`*`matchers...`*`))` + +Creates an [expectation](../gmock_for_dummies.md#setting-expectations) that the +method *`method_name`* of the object *`mock_object`* is called with arguments +that match the given matchers *`matchers...`*. `EXPECT_CALL` must precede any +code that exercises the mock object. + +The parameter *`matchers...`* is a comma-separated list of +[matchers](../gmock_for_dummies.md#matchers-what-arguments-do-we-expect) that +correspond to each argument of the method *`method_name`*. 
The expectation will +apply only to calls of *`method_name`* whose arguments match all of the +matchers. If `(`*`matchers...`*`)` is omitted, the expectation behaves as if +each argument's matcher were a [wildcard matcher (`_`)](matchers.md#wildcard). +See the [Matchers Reference](matchers.md) for a list of all built-in matchers. + +The following chainable clauses can be used to modify the expectation, and they +must be used in the following order: + +```cpp +EXPECT_CALL(mock_object, method_name(matchers...)) + .With(multi_argument_matcher) // Can be used at most once + .Times(cardinality) // Can be used at most once + .InSequence(sequences...) // Can be used any number of times + .After(expectations...) // Can be used any number of times + .WillOnce(action) // Can be used any number of times + .WillRepeatedly(action) // Can be used at most once + .RetiresOnSaturation(); // Can be used at most once +``` + +See details for each modifier clause below. + +#### With {#EXPECT_CALL.With} + +`.With(`*`multi_argument_matcher`*`)` + +Restricts the expectation to apply only to mock function calls whose arguments +as a whole match the multi-argument matcher *`multi_argument_matcher`*. + +GoogleTest passes all of the arguments as one tuple into the matcher. The +parameter *`multi_argument_matcher`* must thus be a matcher of type +`Matcher>`, where `A1, ..., An` are the types of the +function arguments. + +For example, the following code sets the expectation that +`my_mock.SetPosition()` is called with any two arguments, the first argument +being less than the second: + +```cpp +using ::testing::_; +using ::testing::Lt; +... +EXPECT_CALL(my_mock, SetPosition(_, _)) + .With(Lt()); +``` + +GoogleTest provides some built-in matchers for 2-tuples, including the `Lt()` +matcher above. See [Multi-argument Matchers](matchers.md#MultiArgMatchers). + +The `With` clause can be used at most once on an expectation and must be the +first clause. 
+ +#### Times {#EXPECT_CALL.Times} + +`.Times(`*`cardinality`*`)` + +Specifies how many times the mock function call is expected. + +The parameter *`cardinality`* represents the number of expected calls and can be +one of the following, all defined in the `::testing` namespace: + +| Cardinality | Meaning | +| ------------------- | --------------------------------------------------- | +| `AnyNumber()` | The function can be called any number of times. | +| `AtLeast(n)` | The function call is expected at least *n* times. | +| `AtMost(n)` | The function call is expected at most *n* times. | +| `Between(m, n)` | The function call is expected between *m* and *n* times, inclusive. | +| `Exactly(n)` or `n` | The function call is expected exactly *n* times. If *n* is 0, the call should never happen. | + +If the `Times` clause is omitted, GoogleTest infers the cardinality as follows: + +* If neither [`WillOnce`](#EXPECT_CALL.WillOnce) nor + [`WillRepeatedly`](#EXPECT_CALL.WillRepeatedly) are specified, the inferred + cardinality is `Times(1)`. +* If there are *n* `WillOnce` clauses and no `WillRepeatedly` clause, where + *n* >= 1, the inferred cardinality is `Times(n)`. +* If there are *n* `WillOnce` clauses and one `WillRepeatedly` clause, where + *n* >= 0, the inferred cardinality is `Times(AtLeast(n))`. + +The `Times` clause can be used at most once on an expectation. + +#### InSequence {#EXPECT_CALL.InSequence} + +`.InSequence(`*`sequences...`*`)` + +Specifies that the mock function call is expected in a certain sequence. + +The parameter *`sequences...`* is any number of [`Sequence`](#Sequence) objects. +Expected calls assigned to the same sequence are expected to occur in the order +the expectations are declared. 
+ +For example, the following code sets the expectation that the `Reset()` method +of `my_mock` is called before both `GetSize()` and `Describe()`, and `GetSize()` +and `Describe()` can occur in any order relative to each other: + +```cpp +using ::testing::Sequence; +Sequence s1, s2; +... +EXPECT_CALL(my_mock, Reset()) + .InSequence(s1, s2); +EXPECT_CALL(my_mock, GetSize()) + .InSequence(s1); +EXPECT_CALL(my_mock, Describe()) + .InSequence(s2); +``` + +The `InSequence` clause can be used any number of times on an expectation. + +See also the [`InSequence` class](#InSequence). + +#### After {#EXPECT_CALL.After} + +`.After(`*`expectations...`*`)` + +Specifies that the mock function call is expected to occur after one or more +other calls. + +The parameter *`expectations...`* can be up to five +[`Expectation`](#Expectation) or [`ExpectationSet`](#ExpectationSet) objects. +The mock function call is expected to occur after all of the given expectations. + +For example, the following code sets the expectation that the `Describe()` +method of `my_mock` is called only after both `InitX()` and `InitY()` have been +called. + +```cpp +using ::testing::Expectation; +... +Expectation init_x = EXPECT_CALL(my_mock, InitX()); +Expectation init_y = EXPECT_CALL(my_mock, InitY()); +EXPECT_CALL(my_mock, Describe()) + .After(init_x, init_y); +``` + +The `ExpectationSet` object is helpful when the number of prerequisites for an +expectation is large or variable, for example: + +```cpp +using ::testing::ExpectationSet; +... +ExpectationSet all_inits; +// Collect all expectations of InitElement() calls +for (int i = 0; i < element_count; i++) { + all_inits += EXPECT_CALL(my_mock, InitElement(i)); +} +EXPECT_CALL(my_mock, Describe()) + .After(all_inits); // Expect Describe() call after all InitElement() calls +``` + +The `After` clause can be used any number of times on an expectation. 
+ +#### WillOnce {#EXPECT_CALL.WillOnce} + +`.WillOnce(`*`action`*`)` + +Specifies the mock function's actual behavior when invoked, for a single +matching function call. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +The use of `WillOnce` implicitly sets a cardinality on the expectation when +`Times` is not specified. See [`Times`](#EXPECT_CALL.Times). + +Each matching function call will perform the next action in the order declared. +For example, the following code specifies that `my_mock.GetNumber()` is expected +to be called exactly 3 times and will return `1`, `2`, and `3` respectively on +the first, second, and third calls: + +```cpp +using ::testing::Return; +... +EXPECT_CALL(my_mock, GetNumber()) + .WillOnce(Return(1)) + .WillOnce(Return(2)) + .WillOnce(Return(3)); +``` + +The `WillOnce` clause can be used any number of times on an expectation. Unlike +`WillRepeatedly`, the action fed to each `WillOnce` call will be called at most +once, so may be a move-only type and/or have an `&&`-qualified call operator. + +#### WillRepeatedly {#EXPECT_CALL.WillRepeatedly} + +`.WillRepeatedly(`*`action`*`)` + +Specifies the mock function's actual behavior when invoked, for all subsequent +matching function calls. Takes effect after the actions specified in the +[`WillOnce`](#EXPECT_CALL.WillOnce) clauses, if any, have been performed. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +The use of `WillRepeatedly` implicitly sets a cardinality on the expectation +when `Times` is not specified. See [`Times`](#EXPECT_CALL.Times). 
+ +If any `WillOnce` clauses have been specified, matching function calls will +perform those actions before the action specified by `WillRepeatedly`. See the +following example: + +```cpp +using ::testing::Return; +... +EXPECT_CALL(my_mock, GetName()) + .WillRepeatedly(Return("John Doe")); // Return "John Doe" on all calls + +EXPECT_CALL(my_mock, GetNumber()) + .WillOnce(Return(42)) // Return 42 on the first call + .WillRepeatedly(Return(7)); // Return 7 on all subsequent calls +``` + +The `WillRepeatedly` clause can be used at most once on an expectation. + +#### RetiresOnSaturation {#EXPECT_CALL.RetiresOnSaturation} + +`.RetiresOnSaturation()` + +Indicates that the expectation will no longer be active after the expected +number of matching function calls has been reached. + +The `RetiresOnSaturation` clause is only meaningful for expectations with an +upper-bounded cardinality. The expectation will *retire* (no longer match any +function calls) after it has been *saturated* (the upper bound has been +reached). See the following example: + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +... +EXPECT_CALL(my_mock, SetNumber(_)) // Expectation 1 + .Times(AnyNumber()); +EXPECT_CALL(my_mock, SetNumber(7)) // Expectation 2 + .Times(2) + .RetiresOnSaturation(); +``` + +In the above example, the first two calls to `my_mock.SetNumber(7)` match +expectation 2, which then becomes inactive and no longer matches any calls. A +third call to `my_mock.SetNumber(7)` would then match expectation 1. Without +`RetiresOnSaturation()` on expectation 2, a third call to `my_mock.SetNumber(7)` +would match expectation 2 again, producing a failure since the limit of 2 calls +was exceeded. + +The `RetiresOnSaturation` clause can be used at most once on an expectation and +must be the last clause. 
+ +### ON_CALL {#ON_CALL} + +`ON_CALL(`*`mock_object`*`,`*`method_name`*`(`*`matchers...`*`))` + +Defines what happens when the method *`method_name`* of the object +*`mock_object`* is called with arguments that match the given matchers +*`matchers...`*. Requires a modifier clause to specify the method's behavior. +*Does not* set any expectations that the method will be called. + +The parameter *`matchers...`* is a comma-separated list of +[matchers](../gmock_for_dummies.md#matchers-what-arguments-do-we-expect) that +correspond to each argument of the method *`method_name`*. The `ON_CALL` +specification will apply only to calls of *`method_name`* whose arguments match +all of the matchers. If `(`*`matchers...`*`)` is omitted, the behavior is as if +each argument's matcher were a [wildcard matcher (`_`)](matchers.md#wildcard). +See the [Matchers Reference](matchers.md) for a list of all built-in matchers. + +The following chainable clauses can be used to set the method's behavior, and +they must be used in the following order: + +```cpp +ON_CALL(mock_object, method_name(matchers...)) + .With(multi_argument_matcher) // Can be used at most once + .WillByDefault(action); // Required +``` + +See details for each modifier clause below. + +#### With {#ON_CALL.With} + +`.With(`*`multi_argument_matcher`*`)` + +Restricts the specification to only mock function calls whose arguments as a +whole match the multi-argument matcher *`multi_argument_matcher`*. + +GoogleTest passes all of the arguments as one tuple into the matcher. The +parameter *`multi_argument_matcher`* must thus be a matcher of type +`Matcher>`, where `A1, ..., An` are the types of the +function arguments. + +For example, the following code sets the default behavior when +`my_mock.SetPosition()` is called with any two arguments, the first argument +being less than the second: + +```cpp +using ::testing::_; +using ::testing::Lt; +using ::testing::Return; +... 
+ON_CALL(my_mock, SetPosition(_, _)) + .With(Lt()) + .WillByDefault(Return(true)); +``` + +GoogleTest provides some built-in matchers for 2-tuples, including the `Lt()` +matcher above. See [Multi-argument Matchers](matchers.md#MultiArgMatchers). + +The `With` clause can be used at most once with each `ON_CALL` statement. + +#### WillByDefault {#ON_CALL.WillByDefault} + +`.WillByDefault(`*`action`*`)` + +Specifies the default behavior of a matching mock function call. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +For example, the following code specifies that by default, a call to +`my_mock.Greet()` will return `"hello"`: + +```cpp +using ::testing::Return; +... +ON_CALL(my_mock, Greet()) + .WillByDefault(Return("hello")); +``` + +The action specified by `WillByDefault` is superseded by the actions specified +on a matching `EXPECT_CALL` statement, if any. See the +[`WillOnce`](#EXPECT_CALL.WillOnce) and +[`WillRepeatedly`](#EXPECT_CALL.WillRepeatedly) clauses of `EXPECT_CALL`. + +The `WillByDefault` clause must be used exactly once with each `ON_CALL` +statement. + +## Classes {#classes} + +GoogleTest defines the following classes for working with mocks. + +### DefaultValue {#DefaultValue} + +`::testing::DefaultValue` + +Allows a user to specify the default value for a type `T` that is both copyable +and publicly destructible (i.e. anything that can be used as a function return +type). For mock functions with a return type of `T`, this default value is +returned from function calls that do not specify an action. + +Provides the static methods `Set()`, `SetFactory()`, and `Clear()` to manage the +default value: + +```cpp +// Sets the default value to be returned. T must be copy constructible. +DefaultValue::Set(value); + +// Sets a factory. Will be invoked on demand. T must be move constructible. 
+T MakeT();
+DefaultValue<T>::SetFactory(&MakeT);
+
+// Unsets the default value.
+DefaultValue<T>::Clear();
+```
+
+### NiceMock {#NiceMock}
+
+`::testing::NiceMock<T>`
+
+Represents a mock object that suppresses warnings on
+[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The
+template parameter `T` is any mock class, except for another `NiceMock`,
+`NaggyMock`, or `StrictMock`.
+
+Usage of `NiceMock<T>` is analogous to usage of `T`. `NiceMock<T>` is a subclass
+of `T`, so it can be used wherever an object of type `T` is accepted. In
+addition, `NiceMock<T>` can be constructed with any arguments that a constructor
+of `T` accepts.
+
+For example, the following code suppresses warnings on the mock `my_mock` of
+type `MockClass` if a method other than `DoSomething()` is called:
+
+```cpp
+using ::testing::NiceMock;
+...
+NiceMock<MockClass> my_mock("some", "args");
+EXPECT_CALL(my_mock, DoSomething());
+... code that uses my_mock ...
+```
+
+`NiceMock<T>` only works for mock methods defined using the `MOCK_METHOD` macro
+directly in the definition of class `T`. If a mock method is defined in a base
+class of `T`, a warning might still be generated.
+
+`NiceMock<T>` might not work correctly if the destructor of `T` is not virtual.
+
+### NaggyMock {#NaggyMock}
+
+`::testing::NaggyMock<T>`
+
+Represents a mock object that generates warnings on
+[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The
+template parameter `T` is any mock class, except for another `NiceMock`,
+`NaggyMock`, or `StrictMock`.
+
+Usage of `NaggyMock<T>` is analogous to usage of `T`. `NaggyMock<T>` is a
+subclass of `T`, so it can be used wherever an object of type `T` is accepted.
+In addition, `NaggyMock<T>` can be constructed with any arguments that a
+constructor of `T` accepts.
+
+For example, the following code generates warnings on the mock `my_mock` of type
+`MockClass` if a method other than `DoSomething()` is called:
+
+```cpp
+using ::testing::NaggyMock;
+...
+NaggyMock<MockClass> my_mock("some", "args");
+EXPECT_CALL(my_mock, DoSomething());
+... code that uses my_mock ...
+```
+
+Mock objects of type `T` by default behave the same way as `NaggyMock<T>`.
+
+### StrictMock {#StrictMock}
+
+`::testing::StrictMock<T>`
+
+Represents a mock object that generates test failures on
+[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The
+template parameter `T` is any mock class, except for another `NiceMock`,
+`NaggyMock`, or `StrictMock`.
+
+Usage of `StrictMock<T>` is analogous to usage of `T`. `StrictMock<T>` is a
+subclass of `T`, so it can be used wherever an object of type `T` is accepted.
+In addition, `StrictMock<T>` can be constructed with any arguments that a
+constructor of `T` accepts.
+
+For example, the following code generates a test failure on the mock `my_mock`
+of type `MockClass` if a method other than `DoSomething()` is called:
+
+```cpp
+using ::testing::StrictMock;
+...
+StrictMock<MockClass> my_mock("some", "args");
+EXPECT_CALL(my_mock, DoSomething());
+... code that uses my_mock ...
+```
+
+`StrictMock<T>` only works for mock methods defined using the `MOCK_METHOD`
+macro directly in the definition of class `T`. If a mock method is defined in a
+base class of `T`, a failure might not be generated.
+
+`StrictMock<T>` might not work correctly if the destructor of `T` is not
+virtual.
+
+### Sequence {#Sequence}
+
+`::testing::Sequence`
+
+Represents a chronological sequence of expectations. See the
+[`InSequence`](#EXPECT_CALL.InSequence) clause of `EXPECT_CALL` for usage.
+
+### InSequence {#InSequence}
+
+`::testing::InSequence`
+
+An object of this type causes all expectations encountered in its scope to be
+put in an anonymous sequence.
+
+This allows more convenient expression of multiple expectations in a single
+sequence:
+
+```cpp
+using ::testing::InSequence;
+{
+  InSequence seq;
+
+  // The following are expected to occur in the order declared.
+  EXPECT_CALL(...);
+  EXPECT_CALL(...);
+  ...
+ EXPECT_CALL(...); +} +``` + +The name of the `InSequence` object does not matter. + +### Expectation {#Expectation} + +`::testing::Expectation` + +Represents a mock function call expectation as created by +[`EXPECT_CALL`](#EXPECT_CALL): + +```cpp +using ::testing::Expectation; +Expectation my_expectation = EXPECT_CALL(...); +``` + +Useful for specifying sequences of expectations; see the +[`After`](#EXPECT_CALL.After) clause of `EXPECT_CALL`. + +### ExpectationSet {#ExpectationSet} + +`::testing::ExpectationSet` + +Represents a set of mock function call expectations. + +Use the `+=` operator to add [`Expectation`](#Expectation) objects to the set: + +```cpp +using ::testing::ExpectationSet; +ExpectationSet my_expectations; +my_expectations += EXPECT_CALL(...); +``` + +Useful for specifying sequences of expectations; see the +[`After`](#EXPECT_CALL.After) clause of `EXPECT_CALL`. diff --git a/third_party/googletest/docs/reference/testing.md b/third_party/googletest/docs/reference/testing.md new file mode 100644 index 0000000..17225a6 --- /dev/null +++ b/third_party/googletest/docs/reference/testing.md @@ -0,0 +1,1432 @@ +# Testing Reference + + + +This page lists the facilities provided by GoogleTest for writing test programs. +To use them, include the header `gtest/gtest.h`. + +## Macros + +GoogleTest defines the following macros for writing tests. + +### TEST {#TEST} + +
+TEST(TestSuiteName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual test named *`TestName`* in the test suite +*`TestSuiteName`*, consisting of the given statements. + +Both arguments *`TestSuiteName`* and *`TestName`* must be valid C++ identifiers +and must not contain underscores (`_`). Tests in different test suites can have +the same individual name. + +The statements within the test body can be any code under test. +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +### TEST_F {#TEST_F} + +
+TEST_F(TestFixtureName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual test named *`TestName`* that uses the test fixture class +*`TestFixtureName`*. The test suite name is *`TestFixtureName`*. + +Both arguments *`TestFixtureName`* and *`TestName`* must be valid C++ +identifiers and must not contain underscores (`_`). *`TestFixtureName`* must be +the name of a test fixture class—see +[Test Fixtures](../primer.md#same-data-multiple-tests). + +The statements within the test body can be any code under test. +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +### TEST_P {#TEST_P} + +
+TEST_P(TestFixtureName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual value-parameterized test named *`TestName`* that uses the +test fixture class *`TestFixtureName`*. The test suite name is +*`TestFixtureName`*. + +Both arguments *`TestFixtureName`* and *`TestName`* must be valid C++ +identifiers and must not contain underscores (`_`). *`TestFixtureName`* must be +the name of a value-parameterized test fixture class—see +[Value-Parameterized Tests](../advanced.md#value-parameterized-tests). + +The statements within the test body can be any code under test. Within the test +body, the test parameter can be accessed with the `GetParam()` function (see +[`WithParamInterface`](#WithParamInterface)). For example: + +```cpp +TEST_P(MyTestSuite, DoesSomething) { + ... + EXPECT_TRUE(DoSomething(GetParam())); + ... +} +``` + +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +See also [`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P). + +### INSTANTIATE_TEST_SUITE_P {#INSTANTIATE_TEST_SUITE_P} + +`INSTANTIATE_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`param_generator`*`)` +\ +`INSTANTIATE_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`param_generator`*`,`*`name_generator`*`)` + +Instantiates the value-parameterized test suite *`TestSuiteName`* (defined with +[`TEST_P`](#TEST_P)). + +The argument *`InstantiationName`* is a unique name for the instantiation of the +test suite, to distinguish between multiple instantiations. In test output, the +instantiation name is added as a prefix to the test suite name +*`TestSuiteName`*. + +The argument *`param_generator`* is one of the following GoogleTest-provided +functions that generate the test parameters, all defined in the `::testing` +namespace: + + + +| Parameter Generator | Behavior | +| ------------------- | ---------------------------------------------------- | +| `Range(begin, end [, step])` | Yields values `{begin, begin+step, begin+step+step, ...}`. The values do not include `end`. 
`step` defaults to 1. |
+| `Values(v1, v2, ..., vN)` | Yields values `{v1, v2, ..., vN}`. |
+| `ValuesIn(container)` or `ValuesIn(begin,end)` | Yields values from a C-style array, an STL-style container, or an iterator range `[begin, end)`. |
+| `Bool()` | Yields sequence `{false, true}`. |
+| `Combine(g1, g2, ..., gN)` | Yields as `std::tuple` *n*-tuples all combinations (Cartesian product) of the values generated by the given *n* generators `g1`, `g2`, ..., `gN`. |
+| `ConvertGenerator<T>(g)` | Yields values generated by generator `g`, `static_cast` to `T`. |
+
+The optional last argument *`name_generator`* is a function or functor that
+generates custom test name suffixes based on the test parameters. The function
+must accept an argument of type
+[`TestParamInfo<class ParamType>`](#TestParamInfo) and return a `std::string`.
+The test name suffix can only contain alphanumeric characters and underscores.
+GoogleTest provides [`PrintToStringParamName`](#PrintToStringParamName), or a
+custom function can be used for more control:
+
+```cpp
+INSTANTIATE_TEST_SUITE_P(
+    MyInstantiation, MyTestSuite,
+    testing::Values(...),
+    [](const testing::TestParamInfo<MyTestSuite::ParamType>& info) {
+      // Can use info.param here to generate the test suffix
+      std::string name = ...
+      return name;
+    });
+```
+
+For more information, see
+[Value-Parameterized Tests](../advanced.md#value-parameterized-tests).
+
+See also
+[`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST`](#GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST).
+
+### TYPED_TEST_SUITE {#TYPED_TEST_SUITE}
+
+`TYPED_TEST_SUITE(`*`TestFixtureName`*`,`*`Types`*`)`
+
+Defines a typed test suite based on the test fixture *`TestFixtureName`*. The
+test suite name is *`TestFixtureName`*.
+
+The argument *`TestFixtureName`* is a fixture class template, parameterized by a
+type, for example:
+
+```cpp
+template <typename T>
+class MyFixture : public testing::Test {
+ public:
+  ...
+  using List = std::list<T>;
+  static T shared_;
+  T value_;
+};
+```
+
+The argument *`Types`* is a [`Types`](#Types) object representing the list of
+types to run the tests on, for example:
+
+```cpp
+using MyTypes = ::testing::Types<char, int, unsigned int>;
+TYPED_TEST_SUITE(MyFixture, MyTypes);
+```
+
+The type alias (`using` or `typedef`) is necessary for the `TYPED_TEST_SUITE`
+macro to parse correctly.
+
+See also [`TYPED_TEST`](#TYPED_TEST) and
+[Typed Tests](../advanced.md#typed-tests) for more information.
+
+### TYPED_TEST {#TYPED_TEST}
+
+TYPED_TEST(TestSuiteName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual typed test named *`TestName`* in the typed test suite +*`TestSuiteName`*. The test suite must be defined with +[`TYPED_TEST_SUITE`](#TYPED_TEST_SUITE). + +Within the test body, the special name `TypeParam` refers to the type parameter, +and `TestFixture` refers to the fixture class. See the following example: + +```cpp +TYPED_TEST(MyFixture, Example) { + // Inside a test, refer to the special name TypeParam to get the type + // parameter. Since we are inside a derived class template, C++ requires + // us to visit the members of MyFixture via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the 'TestFixture::' + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the 'typename TestFixture::' + // prefix. The 'typename' is required to satisfy the compiler. + typename TestFixture::List values; + + values.push_back(n); + ... +} +``` + +For more information, see [Typed Tests](../advanced.md#typed-tests). + +### TYPED_TEST_SUITE_P {#TYPED_TEST_SUITE_P} + +`TYPED_TEST_SUITE_P(`*`TestFixtureName`*`)` + +Defines a type-parameterized test suite based on the test fixture +*`TestFixtureName`*. The test suite name is *`TestFixtureName`*. + +The argument *`TestFixtureName`* is a fixture class template, parameterized by a +type. See [`TYPED_TEST_SUITE`](#TYPED_TEST_SUITE) for an example. + +See also [`TYPED_TEST_P`](#TYPED_TEST_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### TYPED_TEST_P {#TYPED_TEST_P} + +
+TYPED_TEST_P(TestSuiteName, TestName) {
+  ... statements ...
+}
+
+ +Defines an individual type-parameterized test named *`TestName`* in the +type-parameterized test suite *`TestSuiteName`*. The test suite must be defined +with [`TYPED_TEST_SUITE_P`](#TYPED_TEST_SUITE_P). + +Within the test body, the special name `TypeParam` refers to the type parameter, +and `TestFixture` refers to the fixture class. See [`TYPED_TEST`](#TYPED_TEST) +for an example. + +See also [`REGISTER_TYPED_TEST_SUITE_P`](#REGISTER_TYPED_TEST_SUITE_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### REGISTER_TYPED_TEST_SUITE_P {#REGISTER_TYPED_TEST_SUITE_P} + +`REGISTER_TYPED_TEST_SUITE_P(`*`TestSuiteName`*`,`*`TestNames...`*`)` + +Registers the type-parameterized tests *`TestNames...`* of the test suite +*`TestSuiteName`*. The test suite and tests must be defined with +[`TYPED_TEST_SUITE_P`](#TYPED_TEST_SUITE_P) and [`TYPED_TEST_P`](#TYPED_TEST_P). + +For example: + +```cpp +// Define the test suite and tests. +TYPED_TEST_SUITE_P(MyFixture); +TYPED_TEST_P(MyFixture, HasPropertyA) { ... } +TYPED_TEST_P(MyFixture, HasPropertyB) { ... } + +// Register the tests in the test suite. +REGISTER_TYPED_TEST_SUITE_P(MyFixture, HasPropertyA, HasPropertyB); +``` + +See also [`INSTANTIATE_TYPED_TEST_SUITE_P`](#INSTANTIATE_TYPED_TEST_SUITE_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### INSTANTIATE_TYPED_TEST_SUITE_P {#INSTANTIATE_TYPED_TEST_SUITE_P} + +`INSTANTIATE_TYPED_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`Types`*`)` + +Instantiates the type-parameterized test suite *`TestSuiteName`*. The test suite +must be registered with +[`REGISTER_TYPED_TEST_SUITE_P`](#REGISTER_TYPED_TEST_SUITE_P). + +The argument *`InstantiationName`* is a unique name for the instantiation of the +test suite, to distinguish between multiple instantiations. In test output, the +instantiation name is added as a prefix to the test suite name +*`TestSuiteName`*. 
+
+The argument *`Types`* is a [`Types`](#Types) object representing the list of
+types to run the tests on, for example:
+
+```cpp
+using MyTypes = ::testing::Types<char, int, unsigned int>;
+INSTANTIATE_TYPED_TEST_SUITE_P(MyInstantiation, MyFixture, MyTypes);
+```
+
+The type alias (`using` or `typedef`) is necessary for the
+`INSTANTIATE_TYPED_TEST_SUITE_P` macro to parse correctly.
+
+For more information, see
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests).
+
+### FRIEND_TEST {#FRIEND_TEST}
+
+`FRIEND_TEST(`*`TestSuiteName`*`,`*`TestName`*`)`
+
+Within a class body, declares an individual test as a friend of the class,
+enabling the test to access private class members.
+
+If the class is defined in a namespace, then in order to be friends of the
+class, test fixtures and tests must be defined in the exact same namespace,
+without inline or anonymous namespaces.
+
+For example, if the class definition looks like the following:
+
+```cpp
+namespace my_namespace {
+
+class MyClass {
+  friend class MyClassTest;
+  FRIEND_TEST(MyClassTest, HasPropertyA);
+  FRIEND_TEST(MyClassTest, HasPropertyB);
+  ... definition of class MyClass ...
+};
+
+}  // namespace my_namespace
+```
+
+Then the test code should look like:
+
+```cpp
+namespace my_namespace {
+
+class MyClassTest : public testing::Test {
+  ...
+};
+
+TEST_F(MyClassTest, HasPropertyA) { ... }
+TEST_F(MyClassTest, HasPropertyB) { ... }
+
+}  // namespace my_namespace
+```
+
+See [Testing Private Code](../advanced.md#testing-private-code) for more
+information.
+
+### SCOPED_TRACE {#SCOPED_TRACE}
+
+`SCOPED_TRACE(`*`message`*`)`
+
+Causes the current file name, line number, and the given message *`message`* to
+be added to the failure message for each assertion failure that occurs in the
+scope.
+
+For more information, see
+[Adding Traces to Assertions](../advanced.md#adding-traces-to-assertions).
+
+See also the [`ScopedTrace` class](#ScopedTrace).
+ +### GTEST_SKIP {#GTEST_SKIP} + +`GTEST_SKIP()` + +Prevents further test execution at runtime. + +Can be used in individual test cases or in the `SetUp()` methods of test +environments or test fixtures (classes derived from the +[`Environment`](#Environment) or [`Test`](#Test) classes). If used in a global +test environment `SetUp()` method, it skips all tests in the test program. If +used in a test fixture `SetUp()` method, it skips all tests in the corresponding +test suite. + +Similar to assertions, `GTEST_SKIP` allows streaming a custom message into it. + +See [Skipping Test Execution](../advanced.md#skipping-test-execution) for more +information. + +### GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST {#GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST} + +`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(`*`TestSuiteName`*`)` + +Allows the value-parameterized test suite *`TestSuiteName`* to be +uninstantiated. + +By default, every [`TEST_P`](#TEST_P) call without a corresponding +[`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P) call causes a failing +test in the test suite `GoogleTestVerification`. +`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST` suppresses this failure for the +given test suite. + +## Classes and types + +GoogleTest defines the following classes and types to help with writing tests. + +### AssertionResult {#AssertionResult} + +`testing::AssertionResult` + +A class for indicating whether an assertion was successful. + +When the assertion wasn't successful, the `AssertionResult` object stores a +non-empty failure message that can be retrieved with the object's `message()` +method. + +To create an instance of this class, use one of the factory functions +[`AssertionSuccess()`](#AssertionSuccess) or +[`AssertionFailure()`](#AssertionFailure). + +### AssertionException {#AssertionException} + +`testing::AssertionException` + +Exception which can be thrown from +[`TestEventListener::OnTestPartResult`](#TestEventListener::OnTestPartResult). 
+ +### EmptyTestEventListener {#EmptyTestEventListener} + +`testing::EmptyTestEventListener` + +Provides an empty implementation of all methods in the +[`TestEventListener`](#TestEventListener) interface, such that a subclass only +needs to override the methods it cares about. + +### Environment {#Environment} + +`testing::Environment` + +Represents a global test environment. See +[Global Set-Up and Tear-Down](../advanced.md#global-set-up-and-tear-down). + +#### Protected Methods {#Environment-protected} + +##### SetUp {#Environment::SetUp} + +`virtual void Environment::SetUp()` + +Override this to define how to set up the environment. + +##### TearDown {#Environment::TearDown} + +`virtual void Environment::TearDown()` + +Override this to define how to tear down the environment. + +### ScopedTrace {#ScopedTrace} + +`testing::ScopedTrace` + +An instance of this class causes a trace to be included in every test failure +message generated by code in the scope of the lifetime of the `ScopedTrace` +instance. The effect is undone with the destruction of the instance. + +The `ScopedTrace` constructor has the following form: + +```cpp +template +ScopedTrace(const char* file, int line, const T& message) +``` + +Example usage: + +```cpp +testing::ScopedTrace trace("file.cc", 123, "message"); +``` + +The resulting trace includes the given source file path and line number, and the +given message. The `message` argument can be anything streamable to +`std::ostream`. + +See also [`SCOPED_TRACE`](#SCOPED_TRACE). + +### Test {#Test} + +`testing::Test` + +The abstract class that all tests inherit from. `Test` is not copyable. + +#### Public Methods {#Test-public} + +##### SetUpTestSuite {#Test::SetUpTestSuite} + +`static void Test::SetUpTestSuite()` + +Performs shared setup for all tests in the test suite. GoogleTest calls +`SetUpTestSuite()` before running the first test in the test suite. 
+ +##### TearDownTestSuite {#Test::TearDownTestSuite} + +`static void Test::TearDownTestSuite()` + +Performs shared teardown for all tests in the test suite. GoogleTest calls +`TearDownTestSuite()` after running the last test in the test suite. + +##### HasFatalFailure {#Test::HasFatalFailure} + +`static bool Test::HasFatalFailure()` + +Returns true if and only if the current test has a fatal failure. + +##### HasNonfatalFailure {#Test::HasNonfatalFailure} + +`static bool Test::HasNonfatalFailure()` + +Returns true if and only if the current test has a nonfatal failure. + +##### HasFailure {#Test::HasFailure} + +`static bool Test::HasFailure()` + +Returns true if and only if the current test has any failure, either fatal or +nonfatal. + +##### IsSkipped {#Test::IsSkipped} + +`static bool Test::IsSkipped()` + +Returns true if and only if the current test was skipped. + +##### RecordProperty {#Test::RecordProperty} + +`static void Test::RecordProperty(const std::string& key, const std::string& +value)` \ +`static void Test::RecordProperty(const std::string& key, int value)` + +Logs a property for the current test, test suite, or entire invocation of the +test program. Only the last value for a given key is logged. + +The key must be a valid XML attribute name, and cannot conflict with the ones +already used by GoogleTest (`name`, `file`, `line`, `status`, `time`, +`classname`, `type_param`, and `value_param`). + +`RecordProperty` is `public static` so it can be called from utility functions +that are not members of the test fixture. + +Calls to `RecordProperty` made during the lifespan of the test (from the moment +its constructor starts to the moment its destructor finishes) are output in XML +as attributes of the `` element. Properties recorded from a fixture's +`SetUpTestSuite` or `TearDownTestSuite` methods are logged as attributes of the +corresponding `` element. 
Calls to `RecordProperty` made in the +global context (before or after invocation of `RUN_ALL_TESTS` or from the +`SetUp`/`TearDown` methods of registered `Environment` objects) are output as +attributes of the `` element. + +#### Protected Methods {#Test-protected} + +##### SetUp {#Test::SetUp} + +`virtual void Test::SetUp()` + +Override this to perform test fixture setup. GoogleTest calls `SetUp()` before +running each individual test. + +##### TearDown {#Test::TearDown} + +`virtual void Test::TearDown()` + +Override this to perform test fixture teardown. GoogleTest calls `TearDown()` +after running each individual test. + +### TestWithParam {#TestWithParam} + +`testing::TestWithParam` + +A convenience class which inherits from both [`Test`](#Test) and +[`WithParamInterface`](#WithParamInterface). + +### TestSuite {#TestSuite} + +Represents a test suite. `TestSuite` is not copyable. + +#### Public Methods {#TestSuite-public} + +##### name {#TestSuite::name} + +`const char* TestSuite::name() const` + +Gets the name of the test suite. + +##### type_param {#TestSuite::type_param} + +`const char* TestSuite::type_param() const` + +Returns the name of the parameter type, or `NULL` if this is not a typed or +type-parameterized test suite. See [Typed Tests](../advanced.md#typed-tests) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests). + +##### should_run {#TestSuite::should_run} + +`bool TestSuite::should_run() const` + +Returns true if any test in this test suite should run. + +##### successful_test_count {#TestSuite::successful_test_count} + +`int TestSuite::successful_test_count() const` + +Gets the number of successful tests in this test suite. + +##### skipped_test_count {#TestSuite::skipped_test_count} + +`int TestSuite::skipped_test_count() const` + +Gets the number of skipped tests in this test suite. 
+ +##### failed_test_count {#TestSuite::failed_test_count} + +`int TestSuite::failed_test_count() const` + +Gets the number of failed tests in this test suite. + +##### reportable_disabled_test_count {#TestSuite::reportable_disabled_test_count} + +`int TestSuite::reportable_disabled_test_count() const` + +Gets the number of disabled tests that will be reported in the XML report. + +##### disabled_test_count {#TestSuite::disabled_test_count} + +`int TestSuite::disabled_test_count() const` + +Gets the number of disabled tests in this test suite. + +##### reportable_test_count {#TestSuite::reportable_test_count} + +`int TestSuite::reportable_test_count() const` + +Gets the number of tests to be printed in the XML report. + +##### test_to_run_count {#TestSuite::test_to_run_count} + +`int TestSuite::test_to_run_count() const` + +Get the number of tests in this test suite that should run. + +##### total_test_count {#TestSuite::total_test_count} + +`int TestSuite::total_test_count() const` + +Gets the number of all tests in this test suite. + +##### Passed {#TestSuite::Passed} + +`bool TestSuite::Passed() const` + +Returns true if and only if the test suite passed. + +##### Failed {#TestSuite::Failed} + +`bool TestSuite::Failed() const` + +Returns true if and only if the test suite failed. + +##### elapsed_time {#TestSuite::elapsed_time} + +`TimeInMillis TestSuite::elapsed_time() const` + +Returns the elapsed time, in milliseconds. + +##### start_timestamp {#TestSuite::start_timestamp} + +`TimeInMillis TestSuite::start_timestamp() const` + +Gets the time of the test suite start, in ms from the start of the UNIX epoch. + +##### GetTestInfo {#TestSuite::GetTestInfo} + +`const TestInfo* TestSuite::GetTestInfo(int i) const` + +Returns the [`TestInfo`](#TestInfo) for the `i`-th test among all the tests. `i` +can range from 0 to `total_test_count() - 1`. If `i` is not in that range, +returns `NULL`. 
+ +##### ad_hoc_test_result {#TestSuite::ad_hoc_test_result} + +`const TestResult& TestSuite::ad_hoc_test_result() const` + +Returns the [`TestResult`](#TestResult) that holds test properties recorded +during execution of `SetUpTestSuite` and `TearDownTestSuite`. + +### TestInfo {#TestInfo} + +`testing::TestInfo` + +Stores information about a test. + +#### Public Methods {#TestInfo-public} + +##### test_suite_name {#TestInfo::test_suite_name} + +`const char* TestInfo::test_suite_name() const` + +Returns the test suite name. + +##### name {#TestInfo::name} + +`const char* TestInfo::name() const` + +Returns the test name. + +##### type_param {#TestInfo::type_param} + +`const char* TestInfo::type_param() const` + +Returns the name of the parameter type, or `NULL` if this is not a typed or +type-parameterized test. See [Typed Tests](../advanced.md#typed-tests) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests). + +##### value_param {#TestInfo::value_param} + +`const char* TestInfo::value_param() const` + +Returns the text representation of the value parameter, or `NULL` if this is not +a value-parameterized test. See +[Value-Parameterized Tests](../advanced.md#value-parameterized-tests). + +##### file {#TestInfo::file} + +`const char* TestInfo::file() const` + +Returns the file name where this test is defined. + +##### line {#TestInfo::line} + +`int TestInfo::line() const` + +Returns the line where this test is defined. + +##### is_in_another_shard {#TestInfo::is_in_another_shard} + +`bool TestInfo::is_in_another_shard() const` + +Returns true if this test should not be run because it's in another shard. + +##### should_run {#TestInfo::should_run} + +`bool TestInfo::should_run() const` + +Returns true if this test should run, that is if the test is not disabled (or it +is disabled but the `also_run_disabled_tests` flag has been specified) and its +full name matches the user-specified filter. 
+ +GoogleTest allows the user to filter the tests by their full names. Only the +tests that match the filter will run. See +[Running a Subset of the Tests](../advanced.md#running-a-subset-of-the-tests) +for more information. + +##### is_reportable {#TestInfo::is_reportable} + +`bool TestInfo::is_reportable() const` + +Returns true if and only if this test will appear in the XML report. + +##### result {#TestInfo::result} + +`const TestResult* TestInfo::result() const` + +Returns the result of the test. See [`TestResult`](#TestResult). + +### TestParamInfo {#TestParamInfo} + +`testing::TestParamInfo` + +Describes a parameter to a value-parameterized test. The type `T` is the type of +the parameter. + +Contains the fields `param` and `index` which hold the value of the parameter +and its integer index respectively. + +### UnitTest {#UnitTest} + +`testing::UnitTest` + +This class contains information about the test program. + +`UnitTest` is a singleton class. The only instance is created when +`UnitTest::GetInstance()` is first called. This instance is never deleted. + +`UnitTest` is not copyable. + +#### Public Methods {#UnitTest-public} + +##### GetInstance {#UnitTest::GetInstance} + +`static UnitTest* UnitTest::GetInstance()` + +Gets the singleton `UnitTest` object. The first time this method is called, a +`UnitTest` object is constructed and returned. Consecutive calls will return the +same object. + +##### original_working_dir {#UnitTest::original_working_dir} + +`const char* UnitTest::original_working_dir() const` + +Returns the working directory when the first [`TEST()`](#TEST) or +[`TEST_F()`](#TEST_F) was executed. The `UnitTest` object owns the string. + +##### current_test_suite {#UnitTest::current_test_suite} + +`const TestSuite* UnitTest::current_test_suite() const` + +Returns the [`TestSuite`](#TestSuite) object for the test that's currently +running, or `NULL` if no test is running. 
+ +##### current_test_info {#UnitTest::current_test_info} + +`const TestInfo* UnitTest::current_test_info() const` + +Returns the [`TestInfo`](#TestInfo) object for the test that's currently +running, or `NULL` if no test is running. + +##### random_seed {#UnitTest::random_seed} + +`int UnitTest::random_seed() const` + +Returns the random seed used at the start of the current test run. + +##### successful_test_suite_count {#UnitTest::successful_test_suite_count} + +`int UnitTest::successful_test_suite_count() const` + +Gets the number of successful test suites. + +##### failed_test_suite_count {#UnitTest::failed_test_suite_count} + +`int UnitTest::failed_test_suite_count() const` + +Gets the number of failed test suites. + +##### total_test_suite_count {#UnitTest::total_test_suite_count} + +`int UnitTest::total_test_suite_count() const` + +Gets the number of all test suites. + +##### test_suite_to_run_count {#UnitTest::test_suite_to_run_count} + +`int UnitTest::test_suite_to_run_count() const` + +Gets the number of all test suites that contain at least one test that should +run. + +##### successful_test_count {#UnitTest::successful_test_count} + +`int UnitTest::successful_test_count() const` + +Gets the number of successful tests. + +##### skipped_test_count {#UnitTest::skipped_test_count} + +`int UnitTest::skipped_test_count() const` + +Gets the number of skipped tests. + +##### failed_test_count {#UnitTest::failed_test_count} + +`int UnitTest::failed_test_count() const` + +Gets the number of failed tests. + +##### reportable_disabled_test_count {#UnitTest::reportable_disabled_test_count} + +`int UnitTest::reportable_disabled_test_count() const` + +Gets the number of disabled tests that will be reported in the XML report. + +##### disabled_test_count {#UnitTest::disabled_test_count} + +`int UnitTest::disabled_test_count() const` + +Gets the number of disabled tests. 
+ +##### reportable_test_count {#UnitTest::reportable_test_count} + +`int UnitTest::reportable_test_count() const` + +Gets the number of tests to be printed in the XML report. + +##### total_test_count {#UnitTest::total_test_count} + +`int UnitTest::total_test_count() const` + +Gets the number of all tests. + +##### test_to_run_count {#UnitTest::test_to_run_count} + +`int UnitTest::test_to_run_count() const` + +Gets the number of tests that should run. + +##### start_timestamp {#UnitTest::start_timestamp} + +`TimeInMillis UnitTest::start_timestamp() const` + +Gets the time of the test program start, in ms from the start of the UNIX epoch. + +##### elapsed_time {#UnitTest::elapsed_time} + +`TimeInMillis UnitTest::elapsed_time() const` + +Gets the elapsed time, in milliseconds. + +##### Passed {#UnitTest::Passed} + +`bool UnitTest::Passed() const` + +Returns true if and only if the unit test passed (i.e. all test suites passed). + +##### Failed {#UnitTest::Failed} + +`bool UnitTest::Failed() const` + +Returns true if and only if the unit test failed (i.e. some test suite failed or +something outside of all tests failed). + +##### GetTestSuite {#UnitTest::GetTestSuite} + +`const TestSuite* UnitTest::GetTestSuite(int i) const` + +Gets the [`TestSuite`](#TestSuite) object for the `i`-th test suite among all +the test suites. `i` can range from 0 to `total_test_suite_count() - 1`. If `i` +is not in that range, returns `NULL`. + +##### ad_hoc_test_result {#UnitTest::ad_hoc_test_result} + +`const TestResult& UnitTest::ad_hoc_test_result() const` + +Returns the [`TestResult`](#TestResult) containing information on test failures +and properties logged outside of individual test suites. + +##### listeners {#UnitTest::listeners} + +`TestEventListeners& UnitTest::listeners()` + +Returns the list of event listeners that can be used to track events inside +GoogleTest. See [`TestEventListeners`](#TestEventListeners). 
+ +### TestEventListener {#TestEventListener} + +`testing::TestEventListener` + +The interface for tracing execution of tests. The methods below are listed in +the order the corresponding events are fired. + +#### Public Methods {#TestEventListener-public} + +##### OnTestProgramStart {#TestEventListener::OnTestProgramStart} + +`virtual void TestEventListener::OnTestProgramStart(const UnitTest& unit_test)` + +Fired before any test activity starts. + +##### OnTestIterationStart {#TestEventListener::OnTestIterationStart} + +`virtual void TestEventListener::OnTestIterationStart(const UnitTest& unit_test, +int iteration)` + +Fired before each iteration of tests starts. There may be more than one +iteration if `GTEST_FLAG(repeat)` is set. `iteration` is the iteration index, +starting from 0. + +##### OnEnvironmentsSetUpStart {#TestEventListener::OnEnvironmentsSetUpStart} + +`virtual void TestEventListener::OnEnvironmentsSetUpStart(const UnitTest& +unit_test)` + +Fired before environment set-up for each iteration of tests starts. + +##### OnEnvironmentsSetUpEnd {#TestEventListener::OnEnvironmentsSetUpEnd} + +`virtual void TestEventListener::OnEnvironmentsSetUpEnd(const UnitTest& +unit_test)` + +Fired after environment set-up for each iteration of tests ends. + +##### OnTestSuiteStart {#TestEventListener::OnTestSuiteStart} + +`virtual void TestEventListener::OnTestSuiteStart(const TestSuite& test_suite)` + +Fired before the test suite starts. + +##### OnTestStart {#TestEventListener::OnTestStart} + +`virtual void TestEventListener::OnTestStart(const TestInfo& test_info)` + +Fired before the test starts. + +##### OnTestPartResult {#TestEventListener::OnTestPartResult} + +`virtual void TestEventListener::OnTestPartResult(const TestPartResult& +test_part_result)` + +Fired after a failed assertion or a `SUCCEED()` invocation. 
If you want to throw +an exception from this function to skip to the next test, it must be an +[`AssertionException`](#AssertionException) or inherited from it. + +##### OnTestEnd {#TestEventListener::OnTestEnd} + +`virtual void TestEventListener::OnTestEnd(const TestInfo& test_info)` + +Fired after the test ends. + +##### OnTestSuiteEnd {#TestEventListener::OnTestSuiteEnd} + +`virtual void TestEventListener::OnTestSuiteEnd(const TestSuite& test_suite)` + +Fired after the test suite ends. + +##### OnEnvironmentsTearDownStart {#TestEventListener::OnEnvironmentsTearDownStart} + +`virtual void TestEventListener::OnEnvironmentsTearDownStart(const UnitTest& +unit_test)` + +Fired before environment tear-down for each iteration of tests starts. + +##### OnEnvironmentsTearDownEnd {#TestEventListener::OnEnvironmentsTearDownEnd} + +`virtual void TestEventListener::OnEnvironmentsTearDownEnd(const UnitTest& +unit_test)` + +Fired after environment tear-down for each iteration of tests ends. + +##### OnTestIterationEnd {#TestEventListener::OnTestIterationEnd} + +`virtual void TestEventListener::OnTestIterationEnd(const UnitTest& unit_test, +int iteration)` + +Fired after each iteration of tests finishes. + +##### OnTestProgramEnd {#TestEventListener::OnTestProgramEnd} + +`virtual void TestEventListener::OnTestProgramEnd(const UnitTest& unit_test)` + +Fired after all test activities have ended. + +### TestEventListeners {#TestEventListeners} + +`testing::TestEventListeners` + +Lets users add listeners to track events in GoogleTest. + +#### Public Methods {#TestEventListeners-public} + +##### Append {#TestEventListeners::Append} + +`void TestEventListeners::Append(TestEventListener* listener)` + +Appends an event listener to the end of the list. GoogleTest assumes ownership +of the listener (i.e. it will delete the listener when the test program +finishes). 
+ +##### Release {#TestEventListeners::Release} + +`TestEventListener* TestEventListeners::Release(TestEventListener* listener)` + +Removes the given event listener from the list and returns it. It then becomes +the caller's responsibility to delete the listener. Returns `NULL` if the +listener is not found in the list. + +##### default_result_printer {#TestEventListeners::default_result_printer} + +`TestEventListener* TestEventListeners::default_result_printer() const` + +Returns the standard listener responsible for the default console output. Can be +removed from the listeners list to shut down default console output. Note that +removing this object from the listener list with +[`Release()`](#TestEventListeners::Release) transfers its ownership to the +caller and makes this function return `NULL` the next time. + +##### default_xml_generator {#TestEventListeners::default_xml_generator} + +`TestEventListener* TestEventListeners::default_xml_generator() const` + +Returns the standard listener responsible for the default XML output controlled +by the `--gtest_output=xml` flag. Can be removed from the listeners list by +users who want to shut down the default XML output controlled by this flag and +substitute it with custom one. Note that removing this object from the listener +list with [`Release()`](#TestEventListeners::Release) transfers its ownership to +the caller and makes this function return `NULL` the next time. + +### TestPartResult {#TestPartResult} + +`testing::TestPartResult` + +A copyable object representing the result of a test part (i.e. an assertion or +an explicit `FAIL()`, `ADD_FAILURE()`, or `SUCCESS()`). + +#### Public Methods {#TestPartResult-public} + +##### type {#TestPartResult::type} + +`Type TestPartResult::type() const` + +Gets the outcome of the test part. + +The return type `Type` is an enum defined as follows: + +```cpp +enum Type { + kSuccess, // Succeeded. + kNonFatalFailure, // Failed but the test can continue. 
+ kFatalFailure, // Failed and the test should be terminated. + kSkip // Skipped. +}; +``` + +##### file_name {#TestPartResult::file_name} + +`const char* TestPartResult::file_name() const` + +Gets the name of the source file where the test part took place, or `NULL` if +it's unknown. + +##### line_number {#TestPartResult::line_number} + +`int TestPartResult::line_number() const` + +Gets the line in the source file where the test part took place, or `-1` if it's +unknown. + +##### summary {#TestPartResult::summary} + +`const char* TestPartResult::summary() const` + +Gets the summary of the failure message. + +##### message {#TestPartResult::message} + +`const char* TestPartResult::message() const` + +Gets the message associated with the test part. + +##### skipped {#TestPartResult::skipped} + +`bool TestPartResult::skipped() const` + +Returns true if and only if the test part was skipped. + +##### passed {#TestPartResult::passed} + +`bool TestPartResult::passed() const` + +Returns true if and only if the test part passed. + +##### nonfatally_failed {#TestPartResult::nonfatally_failed} + +`bool TestPartResult::nonfatally_failed() const` + +Returns true if and only if the test part non-fatally failed. + +##### fatally_failed {#TestPartResult::fatally_failed} + +`bool TestPartResult::fatally_failed() const` + +Returns true if and only if the test part fatally failed. + +##### failed {#TestPartResult::failed} + +`bool TestPartResult::failed() const` + +Returns true if and only if the test part failed. + +### TestProperty {#TestProperty} + +`testing::TestProperty` + +A copyable object representing a user-specified test property which can be +output as a key/value string pair. + +#### Public Methods {#TestProperty-public} + +##### key {#key} + +`const char* key() const` + +Gets the user-supplied key. + +##### value {#value} + +`const char* value() const` + +Gets the user-supplied value. 
+ +##### SetValue {#SetValue} + +`void SetValue(const std::string& new_value)` + +Sets a new value, overriding the previous one. + +### TestResult {#TestResult} + +`testing::TestResult` + +Contains information about the result of a single test. + +`TestResult` is not copyable. + +#### Public Methods {#TestResult-public} + +##### total_part_count {#TestResult::total_part_count} + +`int TestResult::total_part_count() const` + +Gets the number of all test parts. This is the sum of the number of successful +test parts and the number of failed test parts. + +##### test_property_count {#TestResult::test_property_count} + +`int TestResult::test_property_count() const` + +Returns the number of test properties. + +##### Passed {#TestResult::Passed} + +`bool TestResult::Passed() const` + +Returns true if and only if the test passed (i.e. no test part failed). + +##### Skipped {#TestResult::Skipped} + +`bool TestResult::Skipped() const` + +Returns true if and only if the test was skipped. + +##### Failed {#TestResult::Failed} + +`bool TestResult::Failed() const` + +Returns true if and only if the test failed. + +##### HasFatalFailure {#TestResult::HasFatalFailure} + +`bool TestResult::HasFatalFailure() const` + +Returns true if and only if the test fatally failed. + +##### HasNonfatalFailure {#TestResult::HasNonfatalFailure} + +`bool TestResult::HasNonfatalFailure() const` + +Returns true if and only if the test has a non-fatal failure. + +##### elapsed_time {#TestResult::elapsed_time} + +`TimeInMillis TestResult::elapsed_time() const` + +Returns the elapsed time, in milliseconds. + +##### start_timestamp {#TestResult::start_timestamp} + +`TimeInMillis TestResult::start_timestamp() const` + +Gets the time of the test case start, in ms from the start of the UNIX epoch. 
+
+##### GetTestPartResult {#TestResult::GetTestPartResult}
+
+`const TestPartResult& TestResult::GetTestPartResult(int i) const`
+
+Returns the [`TestPartResult`](#TestPartResult) for the `i`-th test part result
+among all the results. `i` can range from 0 to `total_part_count() - 1`. If `i`
+is not in that range, aborts the program.
+
+##### GetTestProperty {#TestResult::GetTestProperty}
+
+`const TestProperty& TestResult::GetTestProperty(int i) const`
+
+Returns the [`TestProperty`](#TestProperty) object for the `i`-th test property.
+`i` can range from 0 to `test_property_count() - 1`. If `i` is not in that
+range, aborts the program.
+
+### TimeInMillis {#TimeInMillis}
+
+`testing::TimeInMillis`
+
+An integer type representing time in milliseconds.
+
+### Types {#Types}
+
+`testing::Types<T...>`
+
+Represents a list of types for use in typed tests and type-parameterized tests.
+
+The template argument `T...` can be any number of types, for example:
+
+```
+testing::Types<char, int, unsigned int>
+```
+
+See [Typed Tests](../advanced.md#typed-tests) and
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more
+information.
+
+### WithParamInterface {#WithParamInterface}
+
+`testing::WithParamInterface<T>`
+
+The pure interface class that all value-parameterized tests inherit from.
+
+A value-parameterized test fixture class must inherit from both [`Test`](#Test)
+and `WithParamInterface`. In most cases that just means inheriting from
+[`TestWithParam`](#TestWithParam), but more complicated test hierarchies may
+need to inherit from `Test` and `WithParamInterface` at different levels.
+
+This interface defines the type alias `ParamType` for the parameter type `T` and
+has support for accessing the test parameter value via the `GetParam()` method:
+
+```
+static const ParamType& GetParam()
+```
+
+For more information, see
+[Value-Parameterized Tests](../advanced.md#value-parameterized-tests).
+
+## Functions
+
+GoogleTest defines the following functions to help with writing and running
+tests.
+
+### InitGoogleTest {#InitGoogleTest}
+
+`void testing::InitGoogleTest(int* argc, char** argv)` \
+`void testing::InitGoogleTest(int* argc, wchar_t** argv)` \
+`void testing::InitGoogleTest()`
+
+Initializes GoogleTest. This must be called before calling
+[`RUN_ALL_TESTS()`](#RUN_ALL_TESTS). In particular, it parses the command line
+for the flags that GoogleTest recognizes. Whenever a GoogleTest flag is seen, it
+is removed from `argv`, and `*argc` is decremented.
+
+No value is returned. Instead, the GoogleTest flag variables are updated.
+
+The `InitGoogleTest(int* argc, wchar_t** argv)` overload can be used in Windows
+programs compiled in `UNICODE` mode.
+
+The argument-less `InitGoogleTest()` overload can be used on Arduino/embedded
+platforms where there is no `argc`/`argv`.
+
+### AddGlobalTestEnvironment {#AddGlobalTestEnvironment}
+
+`Environment* testing::AddGlobalTestEnvironment(Environment* env)`
+
+Adds a test environment to the test program. Must be called before
+[`RUN_ALL_TESTS()`](#RUN_ALL_TESTS) is called. See
+[Global Set-Up and Tear-Down](../advanced.md#global-set-up-and-tear-down) for
+more information.
+
+See also [`Environment`](#Environment).
+
+### RegisterTest {#RegisterTest}
+
+```cpp
+template <typename Factory>
+TestInfo* testing::RegisterTest(const char* test_suite_name, const char* test_name,
+                                const char* type_param, const char* value_param,
+                                const char* file, int line, Factory factory)
+```
+
+Dynamically registers a test with the framework.
+
+The `factory` argument is a factory callable (move-constructible) object or
+function pointer that creates a new instance of the `Test` object. It hands
+ownership to the caller. The signature of the callable is `Fixture*()`, where
+`Fixture` is the test fixture class for the test. All tests registered with the
+same `test_suite_name` must return the same fixture type. This is checked at
+runtime.
+
+The framework will infer the fixture class from the factory and will call the
+`SetUpTestSuite` and `TearDownTestSuite` methods for it.
+
+Must be called before [`RUN_ALL_TESTS()`](#RUN_ALL_TESTS) is invoked, otherwise
+behavior is undefined.
+
+See
+[Registering tests programmatically](../advanced.md#registering-tests-programmatically)
+for more information.
+
+### RUN_ALL_TESTS {#RUN_ALL_TESTS}
+
+`int RUN_ALL_TESTS()`
+
+Use this function in `main()` to run all tests. It returns `0` if all tests are
+successful, or `1` otherwise.
+
+`RUN_ALL_TESTS()` should be invoked after the command line has been parsed by
+[`InitGoogleTest()`](#InitGoogleTest).
+
+This function was formerly a macro; thus, it is in the global namespace and has
+an all-caps name.
+
+### AssertionSuccess {#AssertionSuccess}
+
+`AssertionResult testing::AssertionSuccess()`
+
+Creates a successful assertion result. See
+[`AssertionResult`](#AssertionResult).
+
+### AssertionFailure {#AssertionFailure}
+
+`AssertionResult testing::AssertionFailure()`
+
+Creates a failed assertion result. Use the `<<` operator to store a failure
+message:
+
+```cpp
+testing::AssertionFailure() << "My failure message";
+```
+
+See [`AssertionResult`](#AssertionResult).
+
+### StaticAssertTypeEq {#StaticAssertTypeEq}
+
+`testing::StaticAssertTypeEq<T1, T2>()`
+
+Compile-time assertion for type equality. Compiles if and only if `T1` and `T2`
+are the same type. The value it returns is irrelevant.
+
+See [Type Assertions](../advanced.md#type-assertions) for more information.
+
+### PrintToString {#PrintToString}
+
+`std::string testing::PrintToString(x)`
+
+Prints any value `x` using GoogleTest's value printer.
+
+See
+[Teaching GoogleTest How to Print Your Values](../advanced.md#teaching-googletest-how-to-print-your-values)
+for more information.
+
+### PrintToStringParamName {#PrintToStringParamName}
+
+`std::string testing::PrintToStringParamName(TestParamInfo<T>& info)`
+
+A built-in parameterized test name generator which returns the result of
+[`PrintToString`](#PrintToString) called on `info.param`. Does not work when the
+test parameter is a `std::string` or C string. See
+[Specifying Names for Value-Parameterized Test Parameters](../advanced.md#specifying-names-for-value-parameterized-test-parameters)
+for more information.
+
+See also [`TestParamInfo`](#TestParamInfo) and
+[`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P).
diff --git a/third_party/googletest/docs/samples.md b/third_party/googletest/docs/samples.md
new file mode 100644
index 0000000..dedc590
--- /dev/null
+++ b/third_party/googletest/docs/samples.md
@@ -0,0 +1,22 @@
+# Googletest Samples
+
+If you're like us, you'd like to look at
+[googletest samples.](https://github.com/google/googletest/blob/main/googletest/samples)
+The sample directory has a number of well-commented samples showing how to use a
+variety of googletest features.
+
+* Sample #1 shows the basic steps of using googletest to test C++ functions.
+* Sample #2 shows a more complex unit test for a class with multiple member
+  functions.
+* Sample #3 uses a test fixture.
+* Sample #4 teaches you how to use googletest and `googletest.h` together to
+  get the best of both libraries.
+* Sample #5 puts shared testing logic in a base test fixture, and reuses it in
+  derived fixtures.
+* Sample #6 demonstrates type-parameterized tests.
+* Sample #7 teaches the basics of value-parameterized tests.
+* Sample #8 shows using `Combine()` in value-parameterized tests.
+* Sample #9 shows use of the listener API to modify Google Test's console
+  output and the use of its reflection API to inspect test results.
+* Sample #10 shows use of the listener API to implement a primitive memory
+  leak checker.
diff --git a/third_party/googletest/googlemock/CMakeLists.txt b/third_party/googletest/googlemock/CMakeLists.txt new file mode 100644 index 0000000..a9aa072 --- /dev/null +++ b/third_party/googletest/googlemock/CMakeLists.txt @@ -0,0 +1,209 @@ +######################################################################## +# Note: CMake support is community-based. The maintainers do not use CMake +# internally. +# +# CMake build script for Google Mock. +# +# To run the tests for Google Mock itself on Linux, use 'make test' or +# ctest. You can select which tests to run using 'ctest -R regex'. +# For more options, run 'ctest --help'. + +option(gmock_build_tests "Build all of Google Mock's own tests." OFF) + +# A directory to find Google Test sources. +if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/gtest/CMakeLists.txt") + set(gtest_dir gtest) +else() + set(gtest_dir ../googletest) +endif() + +# Defines pre_project_set_up_hermetic_build() and set_up_hermetic_build(). +include("${gtest_dir}/cmake/hermetic_build.cmake" OPTIONAL) + +if (COMMAND pre_project_set_up_hermetic_build) + # Google Test also calls hermetic setup functions from add_subdirectory, + # although its changes will not affect things at the current scope. + pre_project_set_up_hermetic_build() +endif() + +######################################################################## +# +# Project-wide settings + +# Name of the project. +# +# CMake files in this project can refer to the root source directory +# as ${gmock_SOURCE_DIR} and to the root binary directory as +# ${gmock_BINARY_DIR}. +# Language "C" is required for find_package(Threads). +cmake_minimum_required(VERSION 3.13) +project(gmock VERSION ${GOOGLETEST_VERSION} LANGUAGES CXX C) + +if (COMMAND set_up_hermetic_build) + set_up_hermetic_build() +endif() + +# Instructs CMake to process Google Test's CMakeLists.txt and add its +# targets to the current scope. 
We are placing Google Test's binary +# directory in a subdirectory of our own as VC compilation may break +# if they are the same (the default). +add_subdirectory("${gtest_dir}" "${gmock_BINARY_DIR}/${gtest_dir}") + + +# These commands only run if this is the main project +if(CMAKE_PROJECT_NAME STREQUAL "gmock" OR CMAKE_PROJECT_NAME STREQUAL "googletest-distribution") + # BUILD_SHARED_LIBS is a standard CMake variable, but we declare it here to + # make it prominent in the GUI. + option(BUILD_SHARED_LIBS "Build shared libraries (DLLs)." OFF) +else() + mark_as_advanced(gmock_build_tests) +endif() + +# Although Google Test's CMakeLists.txt calls this function, the +# changes there don't affect the current scope. Therefore we have to +# call it again here. +config_compiler_and_linker() # from ${gtest_dir}/cmake/internal_utils.cmake + +# Adds Google Mock's and Google Test's header directories to the search path. +set(gmock_build_include_dirs + "${gmock_SOURCE_DIR}/include" + "${gmock_SOURCE_DIR}" + "${gtest_SOURCE_DIR}/include" + # This directory is needed to build directly from Google Test sources. + "${gtest_SOURCE_DIR}") +include_directories(${gmock_build_include_dirs}) + +######################################################################## +# +# Defines the gmock & gmock_main libraries. User tests should link +# with one of them. + +# Google Mock libraries. We build them using more strict warnings than what +# are used for other targets, to ensure that Google Mock can be compiled by +# a user aggressive about warnings. 
+
+if (MSVC)
+  cxx_library(gmock
+              "${cxx_strict}"
+              "${gtest_dir}/src/gtest-all.cc"
+              src/gmock-all.cc)
+
+  cxx_library(gmock_main
+              "${cxx_strict}"
+              "${gtest_dir}/src/gtest-all.cc"
+              src/gmock-all.cc
+              src/gmock_main.cc)
+else()
+  cxx_library(gmock "${cxx_strict}" src/gmock-all.cc)
+  target_link_libraries(gmock PUBLIC gtest)
+  set_target_properties(gmock PROPERTIES VERSION ${GOOGLETEST_VERSION})
+  cxx_library(gmock_main "${cxx_strict}" src/gmock_main.cc)
+  target_link_libraries(gmock_main PUBLIC gmock)
+  set_target_properties(gmock_main PROPERTIES VERSION ${GOOGLETEST_VERSION})
+endif()
+
+string(REPLACE ";" "$<SEMICOLON>" dirs "${gmock_build_include_dirs}")
+target_include_directories(gmock SYSTEM INTERFACE
+  "$<BUILD_INTERFACE:${dirs}>"
+  "$<INSTALL_INTERFACE:$<INSTALL_PREFIX>/${CMAKE_INSTALL_INCLUDEDIR}>")
+target_include_directories(gmock_main SYSTEM INTERFACE
+  "$<BUILD_INTERFACE:${dirs}>"
+  "$<INSTALL_INTERFACE:$<INSTALL_PREFIX>/${CMAKE_INSTALL_INCLUDEDIR}>")
+
+########################################################################
+#
+# Install rules
+install_project(gmock gmock_main)
+
+########################################################################
+#
+# Google Mock's own tests.
+#
+# You can skip this section if you aren't interested in testing
+# Google Mock itself.
+#
+# The tests are not built by default. To build them, set the
+# gmock_build_tests option to ON. You can do it by running ccmake
+# or specifying the -Dgmock_build_tests=ON flag when running cmake.
+
+if (gmock_build_tests)
+  # This must be set in the root directory for the tests to be run by
+  # 'make test' or ctest.
+  enable_testing()
+
+  if (MINGW OR CYGWIN)
+    add_compile_options("-Wa,-mbig-obj")
+  endif()
+
+  ############################################################
+  # C++ tests built with standard compiler flags.
+ + cxx_test(gmock-actions_test gmock_main) + cxx_test(gmock-cardinalities_test gmock_main) + cxx_test(gmock_ex_test gmock_main) + cxx_test(gmock-function-mocker_test gmock_main) + cxx_test(gmock-internal-utils_test gmock_main) + cxx_test(gmock-matchers-arithmetic_test gmock_main) + cxx_test(gmock-matchers-comparisons_test gmock_main) + cxx_test(gmock-matchers-containers_test gmock_main) + cxx_test(gmock-matchers-misc_test gmock_main) + cxx_test(gmock-more-actions_test gmock_main) + cxx_test(gmock-nice-strict_test gmock_main) + cxx_test(gmock-port_test gmock_main) + cxx_test(gmock-spec-builders_test gmock_main) + cxx_test(gmock_link_test gmock_main test/gmock_link2_test.cc) + cxx_test(gmock_test gmock_main) + + if (DEFINED GTEST_HAS_PTHREAD) + cxx_test(gmock_stress_test gmock) + endif() + + # gmock_all_test is commented to save time building and running tests. + # Uncomment if necessary. + # cxx_test(gmock_all_test gmock_main) + + ############################################################ + # C++ tests built with non-standard compiler flags. 
+ + if (MSVC) + cxx_library(gmock_main_no_exception "${cxx_no_exception}" + "${gtest_dir}/src/gtest-all.cc" src/gmock-all.cc src/gmock_main.cc) + + cxx_library(gmock_main_no_rtti "${cxx_no_rtti}" + "${gtest_dir}/src/gtest-all.cc" src/gmock-all.cc src/gmock_main.cc) + + else() + cxx_library(gmock_main_no_exception "${cxx_no_exception}" src/gmock_main.cc) + target_link_libraries(gmock_main_no_exception PUBLIC gmock) + + cxx_library(gmock_main_no_rtti "${cxx_no_rtti}" src/gmock_main.cc) + target_link_libraries(gmock_main_no_rtti PUBLIC gmock) + endif() + cxx_test_with_flags(gmock-more-actions_no_exception_test "${cxx_no_exception}" + gmock_main_no_exception test/gmock-more-actions_test.cc) + + cxx_test_with_flags(gmock_no_rtti_test "${cxx_no_rtti}" + gmock_main_no_rtti test/gmock-spec-builders_test.cc) + + cxx_shared_library(shared_gmock_main "${cxx_default}" + "${gtest_dir}/src/gtest-all.cc" src/gmock-all.cc src/gmock_main.cc) + + # Tests that a binary can be built with Google Mock as a shared library. On + # some system configurations, it may not possible to run the binary without + # knowing more details about the system configurations. We do not try to run + # this binary. To get a more robust shared library coverage, configure with + # -DBUILD_SHARED_LIBS=ON. + cxx_executable_with_flags(shared_gmock_test_ "${cxx_default}" + shared_gmock_main test/gmock-spec-builders_test.cc) + set_target_properties(shared_gmock_test_ + PROPERTIES + COMPILE_DEFINITIONS "GTEST_LINKED_AS_SHARED_LIBRARY=1") + + ############################################################ + # Python tests. 
+ + cxx_executable(gmock_leak_test_ test gmock_main) + py_test(gmock_leak_test) + + cxx_executable(gmock_output_test_ test gmock) + py_test(gmock_output_test) +endif() diff --git a/third_party/googletest/googlemock/README.md b/third_party/googletest/googlemock/README.md new file mode 100644 index 0000000..7da6065 --- /dev/null +++ b/third_party/googletest/googlemock/README.md @@ -0,0 +1,40 @@ +# Googletest Mocking (gMock) Framework + +### Overview + +Google's framework for writing and using C++ mock classes. It can help you +derive better designs of your system and write better tests. + +It is inspired by: + +* [jMock](http://www.jmock.org/) +* [EasyMock](http://www.easymock.org/) +* [Hamcrest](http://code.google.com/p/hamcrest/) + +It is designed with C++'s specifics in mind. + +gMock: + +- Provides a declarative syntax for defining mocks. +- Can define partial (hybrid) mocks, which are a cross of real and mock + objects. +- Handles functions of arbitrary types and overloaded functions. +- Comes with a rich set of matchers for validating function arguments. +- Uses an intuitive syntax for controlling the behavior of a mock. +- Does automatic verification of expectations (no record-and-replay needed). +- Allows arbitrary (partial) ordering constraints on function calls to be + expressed. +- Lets a user extend it by defining new matchers and actions. +- Does not use exceptions. +- Is easy to learn and use. + +Details and examples can be found here: + +* [gMock for Dummies](https://google.github.io/googletest/gmock_for_dummies.html) +* [Legacy gMock FAQ](https://google.github.io/googletest/gmock_faq.html) +* [gMock Cookbook](https://google.github.io/googletest/gmock_cook_book.html) +* [gMock Cheat Sheet](https://google.github.io/googletest/gmock_cheat_sheet.html) + +GoogleMock is a part of +[GoogleTest C++ testing framework](http://github.com/google/googletest/) and a +subject to the same requirements. 
diff --git a/third_party/googletest/googlemock/cmake/gmock.pc.in b/third_party/googletest/googlemock/cmake/gmock.pc.in new file mode 100644 index 0000000..23c67b5 --- /dev/null +++ b/third_party/googletest/googlemock/cmake/gmock.pc.in @@ -0,0 +1,10 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: gmock +Description: GoogleMock (without main() function) +Version: @PROJECT_VERSION@ +URL: https://github.com/google/googletest +Requires: gtest = @PROJECT_VERSION@ +Libs: -L${libdir} -lgmock @CMAKE_THREAD_LIBS_INIT@ +Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ diff --git a/third_party/googletest/googlemock/cmake/gmock_main.pc.in b/third_party/googletest/googlemock/cmake/gmock_main.pc.in new file mode 100644 index 0000000..66ffea7 --- /dev/null +++ b/third_party/googletest/googlemock/cmake/gmock_main.pc.in @@ -0,0 +1,10 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: gmock_main +Description: GoogleMock (with main() function) +Version: @PROJECT_VERSION@ +URL: https://github.com/google/googletest +Requires: gmock = @PROJECT_VERSION@ +Libs: -L${libdir} -lgmock_main @CMAKE_THREAD_LIBS_INIT@ +Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ diff --git a/third_party/googletest/googlemock/docs/README.md b/third_party/googletest/googlemock/docs/README.md new file mode 100644 index 0000000..1bc57b7 --- /dev/null +++ b/third_party/googletest/googlemock/docs/README.md @@ -0,0 +1,4 @@ +# Content Moved + +We are working on updates to the GoogleTest documentation, which has moved to +the top-level [docs](../../docs) directory. diff --git a/third_party/googletest/googlemock/include/gmock/gmock-actions.h b/third_party/googletest/googlemock/include/gmock/gmock-actions.h new file mode 100644 index 0000000..bd9ba73 --- /dev/null +++ b/third_party/googletest/googlemock/include/gmock/gmock-actions.h @@ -0,0 +1,2297 @@ +// Copyright 2007, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// The ACTION* family of macros can be used in a namespace scope to +// define custom actions easily. The syntax: +// +// ACTION(name) { statements; } +// +// will define an action with the given name that executes the +// statements. The value returned by the statements will be used as +// the return value of the action. 
Inside the statements, you can +// refer to the K-th (0-based) argument of the mock function by +// 'argK', and refer to its type by 'argK_type'. For example: +// +// ACTION(IncrementArg1) { +// arg1_type temp = arg1; +// return ++(*temp); +// } +// +// allows you to write +// +// ...WillOnce(IncrementArg1()); +// +// You can also refer to the entire argument tuple and its type by +// 'args' and 'args_type', and refer to the mock function type and its +// return type by 'function_type' and 'return_type'. +// +// Note that you don't need to specify the types of the mock function +// arguments. However rest assured that your code is still type-safe: +// you'll get a compiler error if *arg1 doesn't support the ++ +// operator, or if the type of ++(*arg1) isn't compatible with the +// mock function's return type, for example. +// +// Sometimes you'll want to parameterize the action. For that you can use +// another macro: +// +// ACTION_P(name, param_name) { statements; } +// +// For example: +// +// ACTION_P(Add, n) { return arg0 + n; } +// +// will allow you to write: +// +// ...WillOnce(Add(5)); +// +// Note that you don't need to provide the type of the parameter +// either. If you need to reference the type of a parameter named +// 'foo', you can write 'foo_type'. For example, in the body of +// ACTION_P(Add, n) above, you can write 'n_type' to refer to the type +// of 'n'. +// +// We also provide ACTION_P2, ACTION_P3, ..., up to ACTION_P10 to support +// multi-parameter actions. +// +// For the purpose of typing, you can view +// +// ACTION_Pk(Foo, p1, ..., pk) { ... } +// +// as shorthand for +// +// template +// FooActionPk Foo(p1_type p1, ..., pk_type pk) { ... } +// +// In particular, you can provide the template type arguments +// explicitly when invoking Foo(), as in Foo(5, false); +// although usually you can rely on the compiler to infer the types +// for you automatically. 
You can assign the result of expression +// Foo(p1, ..., pk) to a variable of type FooActionPk. This can be useful when composing actions. +// +// You can also overload actions with different numbers of parameters: +// +// ACTION_P(Plus, a) { ... } +// ACTION_P2(Plus, a, b) { ... } +// +// While it's tempting to always use the ACTION* macros when defining +// a new action, you should also consider implementing ActionInterface +// or using MakePolymorphicAction() instead, especially if you need to +// use the action a lot. While these approaches require more work, +// they give you more control on the types of the mock function +// arguments and the action parameters, which in general leads to +// better compiler error messages that pay off in the long run. They +// also allow overloading actions based on parameter types (as opposed +// to just based on the number of parameters). +// +// CAVEAT: +// +// ACTION*() can only be used in a namespace scope as templates cannot be +// declared inside of a local class. +// Users can, however, define any local functors (e.g. a lambda) that +// can be used as actions. +// +// MORE INFORMATION: +// +// To learn more about using these macros, please search for 'ACTION' on +// https://github.com/google/googletest/blob/main/docs/gmock_cook_book.md + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ + +#ifndef _WIN32_WCE +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-port.h" +#include "gmock/internal/gmock-pp.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4100) + +namespace testing { + +// To implement an action Foo, define: +// 1. a class FooAction that implements the ActionInterface interface, and +// 2. 
a factory function that creates an Action object from a +// const FooAction*. +// +// The two-level delegation design follows that of Matcher, providing +// consistency for extension developers. It also eases ownership +// management as Action objects can now be copied like plain values. + +namespace internal { + +// BuiltInDefaultValueGetter::Get() returns a +// default-constructed T value. BuiltInDefaultValueGetter::Get() crashes with an error. +// +// This primary template is used when kDefaultConstructible is true. +template +struct BuiltInDefaultValueGetter { + static T Get() { return T(); } +}; +template +struct BuiltInDefaultValueGetter { + static T Get() { + Assert(false, __FILE__, __LINE__, + "Default action undefined for the function return type."); + return internal::Invalid(); + // The above statement will never be reached, but is required in + // order for this function to compile. + } +}; + +// BuiltInDefaultValue::Get() returns the "built-in" default value +// for type T, which is NULL when T is a raw pointer type, 0 when T is +// a numeric type, false when T is bool, or "" when T is string or +// std::string. In addition, in C++11 and above, it turns a +// default-constructed T value if T is default constructible. For any +// other type T, the built-in default T value is undefined, and the +// function will abort the process. +template +class BuiltInDefaultValue { + public: + // This function returns true if and only if type T has a built-in default + // value. + static bool Exists() { return ::std::is_default_constructible::value; } + + static T Get() { + return BuiltInDefaultValueGetter< + T, ::std::is_default_constructible::value>::Get(); + } +}; + +// This partial specialization says that we use the same built-in +// default value for T and const T. 
+template +class BuiltInDefaultValue { + public: + static bool Exists() { return BuiltInDefaultValue::Exists(); } + static T Get() { return BuiltInDefaultValue::Get(); } +}; + +// This partial specialization defines the default values for pointer +// types. +template +class BuiltInDefaultValue { + public: + static bool Exists() { return true; } + static T* Get() { return nullptr; } +}; + +// The following specializations define the default values for +// specific types we care about. +#define GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(type, value) \ + template <> \ + class BuiltInDefaultValue { \ + public: \ + static bool Exists() { return true; } \ + static type Get() { return value; } \ + } + +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(void, ); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(::std::string, ""); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(bool, false); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned char, '\0'); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed char, '\0'); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(char, '\0'); + +// There's no need for a default action for signed wchar_t, as that +// type is the same as wchar_t for gcc, and invalid for MSVC. +// +// There's also no need for a default action for unsigned wchar_t, as +// that type is the same as unsigned int for gcc, and invalid for +// MSVC. 
+#if GMOCK_WCHAR_T_IS_NATIVE_ +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(wchar_t, 0U); // NOLINT +#endif + +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned short, 0U); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed short, 0); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned int, 0U); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed int, 0); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned long, 0UL); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed long, 0L); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned long long, 0); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed long long, 0); // NOLINT +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(float, 0); +GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(double, 0); + +#undef GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_ + +// Partial implementations of metaprogramming types from the standard library +// not available in C++11. + +template +struct negation + // NOLINTNEXTLINE + : std::integral_constant {}; + +// Base case: with zero predicates the answer is always true. +template +struct conjunction : std::true_type {}; + +// With a single predicate, the answer is that predicate. +template +struct conjunction : P1 {}; + +// With multiple predicates the answer is the first predicate if that is false, +// and we recurse otherwise. +template +struct conjunction + : std::conditional, P1>::type {}; + +template +struct disjunction : std::false_type {}; + +template +struct disjunction : P1 {}; + +template +struct disjunction + // NOLINTNEXTLINE + : std::conditional, P1>::type {}; + +template +using void_t = void; + +// Detects whether an expression of type `From` can be implicitly converted to +// `To` according to [conv]. 
In C++17, [conv]/3 defines this as follows: +// +// An expression e can be implicitly converted to a type T if and only if +// the declaration T t=e; is well-formed, for some invented temporary +// variable t ([dcl.init]). +// +// [conv]/2 implies we can use function argument passing to detect whether this +// initialization is valid. +// +// Note that this is distinct from is_convertible, which requires this be valid: +// +// To test() { +// return declval(); +// } +// +// In particular, is_convertible doesn't give the correct answer when `To` and +// `From` are the same non-moveable type since `declval` will be an rvalue +// reference, defeating the guaranteed copy elision that would otherwise make +// this function work. +// +// REQUIRES: `From` is not cv void. +template +struct is_implicitly_convertible { + private: + // A function that accepts a parameter of type T. This can be called with type + // U successfully only if U is implicitly convertible to T. + template + static void Accept(T); + + // A function that creates a value of type T. + template + static T Make(); + + // An overload be selected when implicit conversion from T to To is possible. + template (Make()))> + static std::true_type TestImplicitConversion(int); + + // A fallback overload selected in all other cases. + template + static std::false_type TestImplicitConversion(...); + + public: + using type = decltype(TestImplicitConversion(0)); + static constexpr bool value = type::value; +}; + +// Like std::invoke_result_t from C++17, but works only for objects with call +// operators (not e.g. member function pointers, which we don't need specific +// support for in OnceAction because std::function deals with them). +template +using call_result_t = decltype(std::declval()(std::declval()...)); + +template +struct is_callable_r_impl : std::false_type {}; + +// Specialize the struct for those template arguments where call_result_t is +// well-formed. 
When it's not, the generic template above is chosen, resulting +// in std::false_type. +template +struct is_callable_r_impl>, R, F, Args...> + : std::conditional< + std::is_void::value, // + std::true_type, // + is_implicitly_convertible, R>>::type {}; + +// Like std::is_invocable_r from C++17, but works only for objects with call +// operators. See the note on call_result_t. +template +using is_callable_r = is_callable_r_impl; + +// Like std::as_const from C++17. +template +typename std::add_const::type& as_const(T& t) { + return t; +} + +} // namespace internal + +// Specialized for function types below. +template +class OnceAction; + +// An action that can only be used once. +// +// This is accepted by WillOnce, which doesn't require the underlying action to +// be copy-constructible (only move-constructible), and promises to invoke it as +// an rvalue reference. This allows the action to work with move-only types like +// std::move_only_function in a type-safe manner. +// +// For example: +// +// // Assume we have some API that needs to accept a unique pointer to some +// // non-copyable object Foo. +// void AcceptUniquePointer(std::unique_ptr foo); +// +// // We can define an action that provides a Foo to that API. Because It +// // has to give away its unique pointer, it must not be called more than +// // once, so its call operator is &&-qualified. +// struct ProvideFoo { +// std::unique_ptr foo; +// +// void operator()() && { +// AcceptUniquePointer(std::move(Foo)); +// } +// }; +// +// // This action can be used with WillOnce. +// EXPECT_CALL(mock, Call) +// .WillOnce(ProvideFoo{std::make_unique(...)}); +// +// // But a call to WillRepeatedly will fail to compile. This is correct, +// // since the action cannot correctly be used repeatedly. 
+// EXPECT_CALL(mock, Call) +// .WillRepeatedly(ProvideFoo{std::make_unique(...)}); +// +// A less-contrived example would be an action that returns an arbitrary type, +// whose &&-qualified call operator is capable of dealing with move-only types. +template +class OnceAction final { + private: + // True iff we can use the given callable type (or lvalue reference) directly + // via StdFunctionAdaptor. + template + using IsDirectlyCompatible = internal::conjunction< + // It must be possible to capture the callable in StdFunctionAdaptor. + std::is_constructible::type, Callable>, + // The callable must be compatible with our signature. + internal::is_callable_r::type, + Args...>>; + + // True iff we can use the given callable type via StdFunctionAdaptor once we + // ignore incoming arguments. + template + using IsCompatibleAfterIgnoringArguments = internal::conjunction< + // It must be possible to capture the callable in a lambda. + std::is_constructible::type, Callable>, + // The callable must be invocable with zero arguments, returning something + // convertible to Result. + internal::is_callable_r::type>>; + + public: + // Construct from a callable that is directly compatible with our mocked + // signature: it accepts our function type's arguments and returns something + // convertible to our result type. + template ::type>>, + IsDirectlyCompatible> // + ::value, + int>::type = 0> + OnceAction(Callable&& callable) // NOLINT + : function_(StdFunctionAdaptor::type>( + {}, std::forward(callable))) {} + + // As above, but for a callable that ignores the mocked function's arguments. + template ::type>>, + // Exclude callables for which the overload above works. + // We'd rather provide the arguments if possible. + internal::negation>, + IsCompatibleAfterIgnoringArguments>::value, + int>::type = 0> + OnceAction(Callable&& callable) // NOLINT + // Call the constructor above with a callable + // that ignores the input arguments. 
+ : OnceAction(IgnoreIncomingArguments::type>{ + std::forward(callable)}) {} + + // We are naturally copyable because we store only an std::function, but + // semantically we should not be copyable. + OnceAction(const OnceAction&) = delete; + OnceAction& operator=(const OnceAction&) = delete; + OnceAction(OnceAction&&) = default; + + // Invoke the underlying action callable with which we were constructed, + // handing it the supplied arguments. + Result Call(Args... args) && { + return function_(std::forward(args)...); + } + + private: + // An adaptor that wraps a callable that is compatible with our signature and + // being invoked as an rvalue reference so that it can be used as an + // StdFunctionAdaptor. This throws away type safety, but that's fine because + // this is only used by WillOnce, which we know calls at most once. + // + // Once we have something like std::move_only_function from C++23, we can do + // away with this. + template + class StdFunctionAdaptor final { + public: + // A tag indicating that the (otherwise universal) constructor is accepting + // the callable itself, instead of e.g. stealing calls for the move + // constructor. + struct CallableTag final {}; + + template + explicit StdFunctionAdaptor(CallableTag, F&& callable) + : callable_(std::make_shared(std::forward(callable))) {} + + // Rather than explicitly returning Result, we return whatever the wrapped + // callable returns. This allows for compatibility with existing uses like + // the following, when the mocked function returns void: + // + // EXPECT_CALL(mock_fn_, Call) + // .WillOnce([&] { + // [...] + // return 0; + // }); + // + // Such a callable can be turned into std::function. If we use an + // explicit return type of Result here then it *doesn't* work with + // std::function, because we'll get a "void function should not return a + // value" error. + // + // We need not worry about incompatible result types because the SFINAE on + // OnceAction already checks this for us. 
std::is_invocable_r_v itself makes + // the same allowance for void result types. + template + internal::call_result_t operator()( + ArgRefs&&... args) const { + return std::move(*callable_)(std::forward(args)...); + } + + private: + // We must put the callable on the heap so that we are copyable, which + // std::function needs. + std::shared_ptr callable_; + }; + + // An adaptor that makes a callable that accepts zero arguments callable with + // our mocked arguments. + template + struct IgnoreIncomingArguments { + internal::call_result_t operator()(Args&&...) { + return std::move(callable)(); + } + + Callable callable; + }; + + std::function function_; +}; + +// When an unexpected function call is encountered, Google Mock will +// let it return a default value if the user has specified one for its +// return type, or if the return type has a built-in default value; +// otherwise Google Mock won't know what value to return and will have +// to abort the process. +// +// The DefaultValue class allows a user to specify the +// default value for a type T that is both copyable and publicly +// destructible (i.e. anything that can be used as a function return +// type). The usage is: +// +// // Sets the default value for type T to be foo. +// DefaultValue::Set(foo); +template +class DefaultValue { + public: + // Sets the default value for type T; requires T to be + // copy-constructable and have a public destructor. + static void Set(T x) { + delete producer_; + producer_ = new FixedValueProducer(x); + } + + // Provides a factory function to be called to generate the default value. + // This method can be used even if T is only move-constructible, but it is not + // limited to that case. + typedef T (*FactoryFunction)(); + static void SetFactory(FactoryFunction factory) { + delete producer_; + producer_ = new FactoryValueProducer(factory); + } + + // Unsets the default value for type T. 
+ static void Clear() { + delete producer_; + producer_ = nullptr; + } + + // Returns true if and only if the user has set the default value for type T. + static bool IsSet() { return producer_ != nullptr; } + + // Returns true if T has a default return value set by the user or there + // exists a built-in default value. + static bool Exists() { + return IsSet() || internal::BuiltInDefaultValue::Exists(); + } + + // Returns the default value for type T if the user has set one; + // otherwise returns the built-in default value. Requires that Exists() + // is true, which ensures that the return value is well-defined. + static T Get() { + return producer_ == nullptr ? internal::BuiltInDefaultValue::Get() + : producer_->Produce(); + } + + private: + class ValueProducer { + public: + virtual ~ValueProducer() = default; + virtual T Produce() = 0; + }; + + class FixedValueProducer : public ValueProducer { + public: + explicit FixedValueProducer(T value) : value_(value) {} + T Produce() override { return value_; } + + private: + const T value_; + FixedValueProducer(const FixedValueProducer&) = delete; + FixedValueProducer& operator=(const FixedValueProducer&) = delete; + }; + + class FactoryValueProducer : public ValueProducer { + public: + explicit FactoryValueProducer(FactoryFunction factory) + : factory_(factory) {} + T Produce() override { return factory_(); } + + private: + const FactoryFunction factory_; + FactoryValueProducer(const FactoryValueProducer&) = delete; + FactoryValueProducer& operator=(const FactoryValueProducer&) = delete; + }; + + static ValueProducer* producer_; +}; + +// This partial specialization allows a user to set default values for +// reference types. +template +class DefaultValue { + public: + // Sets the default value for type T&. + static void Set(T& x) { // NOLINT + address_ = &x; + } + + // Unsets the default value for type T&. 
+ static void Clear() { address_ = nullptr; } + + // Returns true if and only if the user has set the default value for type T&. + static bool IsSet() { return address_ != nullptr; } + + // Returns true if T has a default return value set by the user or there + // exists a built-in default value. + static bool Exists() { + return IsSet() || internal::BuiltInDefaultValue::Exists(); + } + + // Returns the default value for type T& if the user has set one; + // otherwise returns the built-in default value if there is one; + // otherwise aborts the process. + static T& Get() { + return address_ == nullptr ? internal::BuiltInDefaultValue::Get() + : *address_; + } + + private: + static T* address_; +}; + +// This specialization allows DefaultValue::Get() to +// compile. +template <> +class DefaultValue { + public: + static bool Exists() { return true; } + static void Get() {} +}; + +// Points to the user-set default value for type T. +template +typename DefaultValue::ValueProducer* DefaultValue::producer_ = nullptr; + +// Points to the user-set default value for type T&. +template +T* DefaultValue::address_ = nullptr; + +// Implement this interface to define an action for function type F. +template +class ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + ActionInterface() = default; + virtual ~ActionInterface() = default; + + // Performs the action. This method is not const, as in general an + // action can have side effects and be stateful. For example, a + // get-the-next-element-from-the-collection action will need to + // remember the current element. 
+ virtual Result Perform(const ArgumentTuple& args) = 0; + + private: + ActionInterface(const ActionInterface&) = delete; + ActionInterface& operator=(const ActionInterface&) = delete; +}; + +template +class Action; + +// An Action is a copyable and IMMUTABLE (except by assignment) +// object that represents an action to be taken when a mock function of type +// R(Args...) is called. The implementation of Action is just a +// std::shared_ptr to const ActionInterface. Don't inherit from Action! You +// can view an object implementing ActionInterface as a concrete action +// (including its current state), and an Action object as a handle to it. +template +class Action { + private: + using F = R(Args...); + + // Adapter class to allow constructing Action from a legacy ActionInterface. + // New code should create Actions from functors instead. + struct ActionAdapter { + // Adapter must be copyable to satisfy std::function requirements. + ::std::shared_ptr> impl_; + + template + typename internal::Function::Result operator()(InArgs&&... args) { + return impl_->Perform( + ::std::forward_as_tuple(::std::forward(args)...)); + } + }; + + template + using IsCompatibleFunctor = std::is_constructible, G>; + + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + // Constructs a null Action. Needed for storing Action objects in + // STL containers. + Action() = default; + + // Construct an Action from a specified callable. + // This cannot take std::function directly, because then Action would not be + // directly constructible from lambda (it would require two conversions). + template < + typename G, + typename = typename std::enable_if, std::is_constructible, + G>>::value>::type> + Action(G&& fun) { // NOLINT + Init(::std::forward(fun), IsCompatibleFunctor()); + } + + // Constructs an Action from its implementation. 
+ explicit Action(ActionInterface* impl) + : fun_(ActionAdapter{::std::shared_ptr>(impl)}) {} + + // This constructor allows us to turn an Action object into an + // Action, as long as F's arguments can be implicitly converted + // to Func's and Func's return type can be implicitly converted to F's. + template + Action(const Action& action) // NOLINT + : fun_(action.fun_) {} + + // Returns true if and only if this is the DoDefault() action. + bool IsDoDefault() const { return fun_ == nullptr; } + + // Performs the action. Note that this method is const even though + // the corresponding method in ActionInterface is not. The reason + // is that a const Action means that it cannot be re-bound to + // another concrete action, not that the concrete action it binds to + // cannot change state. (Think of the difference between a const + // pointer and a pointer to const.) + Result Perform(ArgumentTuple args) const { + if (IsDoDefault()) { + internal::IllegalDoDefault(__FILE__, __LINE__); + } + return internal::Apply(fun_, ::std::move(args)); + } + + // An action can be used as a OnceAction, since it's obviously safe to call it + // once. + operator OnceAction() const { // NOLINT + // Return a OnceAction-compatible callable that calls Perform with the + // arguments it is provided. We could instead just return fun_, but then + // we'd need to handle the IsDoDefault() case separately. + struct OA { + Action action; + + R operator()(Args... args) && { + return action.Perform( + std::forward_as_tuple(std::forward(args)...)); + } + }; + + return OA{*this}; + } + + private: + template + friend class Action; + + template + void Init(G&& g, ::std::true_type) { + fun_ = ::std::forward(g); + } + + template + void Init(G&& g, ::std::false_type) { + fun_ = IgnoreArgs::type>{::std::forward(g)}; + } + + template + struct IgnoreArgs { + template + Result operator()(const InArgs&...) 
const { + return function_impl(); + } + + FunctionImpl function_impl; + }; + + // fun_ is an empty function if and only if this is the DoDefault() action. + ::std::function fun_; +}; + +// The PolymorphicAction class template makes it easy to implement a +// polymorphic action (i.e. an action that can be used in mock +// functions of than one type, e.g. Return()). +// +// To define a polymorphic action, a user first provides a COPYABLE +// implementation class that has a Perform() method template: +// +// class FooAction { +// public: +// template +// Result Perform(const ArgumentTuple& args) const { +// // Processes the arguments and returns a result, using +// // std::get(args) to get the N-th (0-based) argument in the tuple. +// } +// ... +// }; +// +// Then the user creates the polymorphic action using +// MakePolymorphicAction(object) where object has type FooAction. See +// the definition of Return(void) and SetArgumentPointee(value) for +// complete examples. +template +class PolymorphicAction { + public: + explicit PolymorphicAction(const Impl& impl) : impl_(impl) {} + + template + operator Action() const { + return Action(new MonomorphicImpl(impl_)); + } + + private: + template + class MonomorphicImpl : public ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + explicit MonomorphicImpl(const Impl& impl) : impl_(impl) {} + + Result Perform(const ArgumentTuple& args) override { + return impl_.template Perform(args); + } + + private: + Impl impl_; + }; + + Impl impl_; +}; + +// Creates an Action from its implementation and returns it. The +// created Action object owns the implementation. +template +Action MakeAction(ActionInterface* impl) { + return Action(impl); +} + +// Creates a polymorphic action from its implementation. 
This is +// easier to use than the PolymorphicAction constructor as it +// doesn't require you to explicitly write the template argument, e.g. +// +// MakePolymorphicAction(foo); +// vs +// PolymorphicAction(foo); +template +inline PolymorphicAction MakePolymorphicAction(const Impl& impl) { + return PolymorphicAction(impl); +} + +namespace internal { + +// Helper struct to specialize ReturnAction to execute a move instead of a copy +// on return. Useful for move-only types, but could be used on any type. +template +struct ByMoveWrapper { + explicit ByMoveWrapper(T value) : payload(std::move(value)) {} + T payload; +}; + +// The general implementation of Return(R). Specializations follow below. +template +class ReturnAction final { + public: + explicit ReturnAction(R value) : value_(std::move(value)) {} + + template >, // + negation>, // + std::is_convertible, // + std::is_move_constructible>::value>::type> + operator OnceAction() && { // NOLINT + return Impl(std::move(value_)); + } + + template >, // + negation>, // + std::is_convertible, // + std::is_copy_constructible>::value>::type> + operator Action() const { // NOLINT + return Impl(value_); + } + + private: + // Implements the Return(x) action for a mock function that returns type U. + template + class Impl final { + public: + // The constructor used when the return value is allowed to move from the + // input value (i.e. we are converting to OnceAction). + explicit Impl(R&& input_value) + : state_(new State(std::move(input_value))) {} + + // The constructor used when the return value is not allowed to move from + // the input value (i.e. we are converting to Action). + explicit Impl(const R& input_value) : state_(new State(input_value)) {} + + U operator()() && { return std::move(state_->value); } + U operator()() const& { return state_->value; } + + private: + // We put our state on the heap so that the compiler-generated copy/move + // constructors work correctly even when U is a reference-like type. 
This is + // necessary only because we eagerly create State::value (see the note on + // that symbol for details). If we instead had only the input value as a + // member then the default constructors would work fine. + // + // For example, when R is std::string and U is std::string_view, value is a + // reference to the string backed by input_value. The copy constructor would + // copy both, so that we wind up with a new input_value object (with the + // same contents) and a reference to the *old* input_value object rather + // than the new one. + struct State { + explicit State(const R& input_value_in) + : input_value(input_value_in), + // Make an implicit conversion to Result before initializing the U + // object we store, avoiding calling any explicit constructor of U + // from R. + // + // This simulates the language rules: a function with return type U + // that does `return R()` requires R to be implicitly convertible to + // U, and uses that path for the conversion, even U Result has an + // explicit constructor from R. + value(ImplicitCast_(internal::as_const(input_value))) {} + + // As above, but for the case where we're moving from the ReturnAction + // object because it's being used as a OnceAction. + explicit State(R&& input_value_in) + : input_value(std::move(input_value_in)), + // For the same reason as above we make an implicit conversion to U + // before initializing the value. + // + // Unlike above we provide the input value as an rvalue to the + // implicit conversion because this is a OnceAction: it's fine if it + // wants to consume the input value. + value(ImplicitCast_(std::move(input_value))) {} + + // A copy of the value originally provided by the user. We retain this in + // addition to the value of the mock function's result type below in case + // the latter is a reference-like type. See the std::string_view example + // in the documentation on Return. 
+ R input_value; + + // The value we actually return, as the type returned by the mock function + // itself. + // + // We eagerly initialize this here, rather than lazily doing the implicit + // conversion automatically each time Perform is called, for historical + // reasons: in 2009-11, commit a070cbd91c (Google changelist 13540126) + // made the Action conversion operator eagerly convert the R value to + // U, but without keeping the R alive. This broke the use case discussed + // in the documentation for Return, making reference-like types such as + // std::string_view not safe to use as U where the input type R is a + // value-like type such as std::string. + // + // The example the commit gave was not very clear, nor was the issue + // thread (https://github.com/google/googlemock/issues/86), but it seems + // the worry was about reference-like input types R that flatten to a + // value-like type U when being implicitly converted. An example of this + // is std::vector::reference, which is often a proxy type with an + // reference to the underlying vector: + // + // // Helper method: have the mock function return bools according + // // to the supplied script. + // void SetActions(MockFunction& mock, + // const std::vector& script) { + // for (size_t i = 0; i < script.size(); ++i) { + // EXPECT_CALL(mock, Call(i)).WillOnce(Return(script[i])); + // } + // } + // + // TEST(Foo, Bar) { + // // Set actions using a temporary vector, whose operator[] + // // returns proxy objects that references that will be + // // dangling once the call to SetActions finishes and the + // // vector is destroyed. + // MockFunction mock; + // SetActions(mock, {false, true}); + // + // EXPECT_FALSE(mock.AsStdFunction()(0)); + // EXPECT_TRUE(mock.AsStdFunction()(1)); + // } + // + // This eager conversion helps with a simple case like this, but doesn't + // fully make these types work in general. 
For example the following still + // uses a dangling reference: + // + // TEST(Foo, Baz) { + // MockFunction()> mock; + // + // // Return the same vector twice, and then the empty vector + // // thereafter. + // auto action = Return(std::initializer_list{ + // "taco", "burrito", + // }); + // + // EXPECT_CALL(mock, Call) + // .WillOnce(action) + // .WillOnce(action) + // .WillRepeatedly(Return(std::vector{})); + // + // EXPECT_THAT(mock.AsStdFunction()(), + // ElementsAre("taco", "burrito")); + // EXPECT_THAT(mock.AsStdFunction()(), + // ElementsAre("taco", "burrito")); + // EXPECT_THAT(mock.AsStdFunction()(), IsEmpty()); + // } + // + U value; + }; + + const std::shared_ptr state_; + }; + + R value_; +}; + +// A specialization of ReturnAction when R is ByMoveWrapper for some T. +// +// This version applies the type system-defeating hack of moving from T even in +// the const call operator, checking at runtime that it isn't called more than +// once, since the user has declared their intent to do so by using ByMove. +template +class ReturnAction> final { + public: + explicit ReturnAction(ByMoveWrapper wrapper) + : state_(new State(std::move(wrapper.payload))) {} + + T operator()() const { + GTEST_CHECK_(!state_->called) + << "A ByMove() action must be performed at most once."; + + state_->called = true; + return std::move(state_->value); + } + + private: + // We store our state on the heap so that we are copyable as required by + // Action, despite the fact that we are stateful and T may not be copyable. + struct State { + explicit State(T&& value_in) : value(std::move(value_in)) {} + + T value; + bool called = false; + }; + + const std::shared_ptr state_; +}; + +// Implements the ReturnNull() action. +class ReturnNullAction { + public: + // Allows ReturnNull() to be used in any pointer-returning function. In C++11 + // this is enforced by returning nullptr, and in non-C++11 by asserting a + // pointer type on compile time. 
+ template + static Result Perform(const ArgumentTuple&) { + return nullptr; + } +}; + +// Implements the Return() action. +class ReturnVoidAction { + public: + // Allows Return() to be used in any void-returning function. + template + static void Perform(const ArgumentTuple&) { + static_assert(std::is_void::value, "Result should be void."); + } +}; + +// Implements the polymorphic ReturnRef(x) action, which can be used +// in any function that returns a reference to the type of x, +// regardless of the argument types. +template +class ReturnRefAction { + public: + // Constructs a ReturnRefAction object from the reference to be returned. + explicit ReturnRefAction(T& ref) : ref_(ref) {} // NOLINT + + // This template type conversion operator allows ReturnRef(x) to be + // used in ANY function that returns a reference to x's type. + template + operator Action() const { + typedef typename Function::Result Result; + // Asserts that the function return type is a reference. This + // catches the user error of using ReturnRef(x) when Return(x) + // should be used, and generates some helpful error message. + static_assert(std::is_reference::value, + "use Return instead of ReturnRef to return a value"); + return Action(new Impl(ref_)); + } + + private: + // Implements the ReturnRef(x) action for a particular function type F. + template + class Impl : public ActionInterface { + public: + typedef typename Function::Result Result; + typedef typename Function::ArgumentTuple ArgumentTuple; + + explicit Impl(T& ref) : ref_(ref) {} // NOLINT + + Result Perform(const ArgumentTuple&) override { return ref_; } + + private: + T& ref_; + }; + + T& ref_; +}; + +// Implements the polymorphic ReturnRefOfCopy(x) action, which can be +// used in any function that returns a reference to the type of x, +// regardless of the argument types. +template +class ReturnRefOfCopyAction { + public: + // Constructs a ReturnRefOfCopyAction object from the reference to + // be returned. 
+ explicit ReturnRefOfCopyAction(const T& value) : value_(value) {} // NOLINT + + // This template type conversion operator allows ReturnRefOfCopy(x) to be + // used in ANY function that returns a reference to x's type. + template + operator Action() const { + typedef typename Function::Result Result; + // Asserts that the function return type is a reference. This + // catches the user error of using ReturnRefOfCopy(x) when Return(x) + // should be used, and generates some helpful error message. + static_assert(std::is_reference::value, + "use Return instead of ReturnRefOfCopy to return a value"); + return Action(new Impl(value_)); + } + + private: + // Implements the ReturnRefOfCopy(x) action for a particular function type F. + template + class Impl : public ActionInterface { + public: + typedef typename Function::Result Result; + typedef typename Function::ArgumentTuple ArgumentTuple; + + explicit Impl(const T& value) : value_(value) {} // NOLINT + + Result Perform(const ArgumentTuple&) override { return value_; } + + private: + T value_; + }; + + const T value_; +}; + +// Implements the polymorphic ReturnRoundRobin(v) action, which can be +// used in any function that returns the element_type of v. +template +class ReturnRoundRobinAction { + public: + explicit ReturnRoundRobinAction(std::vector values) { + GTEST_CHECK_(!values.empty()) + << "ReturnRoundRobin requires at least one element."; + state_->values = std::move(values); + } + + template + T operator()(Args&&...) const { + return state_->Next(); + } + + private: + struct State { + T Next() { + T ret_val = values[i++]; + if (i == values.size()) i = 0; + return ret_val; + } + + std::vector values; + size_t i = 0; + }; + std::shared_ptr state_ = std::make_shared(); +}; + +// Implements the polymorphic DoDefault() action. +class DoDefaultAction { + public: + // This template type conversion operator allows DoDefault() to be + // used in any function. 
+ template + operator Action() const { + return Action(); + } // NOLINT +}; + +// Implements the Assign action to set a given pointer referent to a +// particular value. +template +class AssignAction { + public: + AssignAction(T1* ptr, T2 value) : ptr_(ptr), value_(value) {} + + template + void Perform(const ArgumentTuple& /* args */) const { + *ptr_ = value_; + } + + private: + T1* const ptr_; + const T2 value_; +}; + +#ifndef GTEST_OS_WINDOWS_MOBILE + +// Implements the SetErrnoAndReturn action to simulate return from +// various system calls and libc functions. +template +class SetErrnoAndReturnAction { + public: + SetErrnoAndReturnAction(int errno_value, T result) + : errno_(errno_value), result_(result) {} + template + Result Perform(const ArgumentTuple& /* args */) const { + errno = errno_; + return result_; + } + + private: + const int errno_; + const T result_; +}; + +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Implements the SetArgumentPointee(x) action for any function +// whose N-th argument (0-based) is a pointer to x's type. +template +struct SetArgumentPointeeAction { + A value; + + template + void operator()(const Args&... args) const { + *::std::get(std::tie(args...)) = value; + } +}; + +// Implements the Invoke(object_ptr, &Class::Method) action. +template +struct InvokeMethodAction { + Class* const obj_ptr; + const MethodPtr method_ptr; + + template + auto operator()(Args&&... args) const + -> decltype((obj_ptr->*method_ptr)(std::forward(args)...)) { + return (obj_ptr->*method_ptr)(std::forward(args)...); + } +}; + +// Implements the InvokeWithoutArgs(f) action. The template argument +// FunctionImpl is the implementation type of f, which can be either a +// function pointer or a functor. InvokeWithoutArgs(f) can be used as an +// Action as long as f's type is compatible with F. +template +struct InvokeWithoutArgsAction { + FunctionImpl function_impl; + + // Allows InvokeWithoutArgs(f) to be used as any action whose type is + // compatible with f. 
+ template + auto operator()(const Args&...) -> decltype(function_impl()) { + return function_impl(); + } +}; + +// Implements the InvokeWithoutArgs(object_ptr, &Class::Method) action. +template +struct InvokeMethodWithoutArgsAction { + Class* const obj_ptr; + const MethodPtr method_ptr; + + using ReturnType = + decltype((std::declval()->*std::declval())()); + + template + ReturnType operator()(const Args&...) const { + return (obj_ptr->*method_ptr)(); + } +}; + +// Implements the IgnoreResult(action) action. +template +class IgnoreResultAction { + public: + explicit IgnoreResultAction(const A& action) : action_(action) {} + + template + operator Action() const { + // Assert statement belongs here because this is the best place to verify + // conditions on F. It produces the clearest error messages + // in most compilers. + // Impl really belongs in this scope as a local class but can't + // because MSVC produces duplicate symbols in different translation units + // in this case. Until MS fixes that bug we put Impl into the class scope + // and put the typedef both here (for use in assert statement) and + // in the Impl class. But both definitions must be the same. + typedef typename internal::Function::Result Result; + + // Asserts at compile time that F returns void. + static_assert(std::is_void::value, "Result type should be void."); + + return Action(new Impl(action_)); + } + + private: + template + class Impl : public ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + explicit Impl(const A& action) : action_(action) {} + + void Perform(const ArgumentTuple& args) override { + // Performs the action and ignores its result. + action_.Perform(args); + } + + private: + // Type OriginalFunction is the same as F except that its return + // type is IgnoredValue. 
+ typedef + typename internal::Function::MakeResultIgnoredValue OriginalFunction; + + const Action action_; + }; + + const A action_; +}; + +template +struct WithArgsAction { + InnerAction inner_action; + + // The signature of the function as seen by the inner action, given an out + // action with the given result and argument types. + template + using InnerSignature = + R(typename std::tuple_element>::type...); + + // Rather than a call operator, we must define conversion operators to + // particular action types. This is necessary for embedded actions like + // DoDefault(), which rely on an action conversion operators rather than + // providing a call operator because even with a particular set of arguments + // they don't have a fixed return type. + + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>...)>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + struct OA { + OnceAction> inner_action; + + R operator()(Args&&... args) && { + return std::move(inner_action) + .Call(std::get( + std::forward_as_tuple(std::forward(args)...))...); + } + }; + + return OA{std::move(inner_action)}; + } + + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>...)>>::value, + int>::type = 0> + operator Action() const { // NOLINT + Action> converted(inner_action); + + return [converted](Args&&... args) -> R { + return converted.Perform(std::forward_as_tuple( + std::get(std::forward_as_tuple(std::forward(args)...))...)); + }; + } +}; + +template +class DoAllAction; + +// Base case: only a single action. +template +class DoAllAction { + public: + struct UserConstructorTag {}; + + template + explicit DoAllAction(UserConstructorTag, T&& action) + : final_action_(std::forward(action)) {} + + // Rather than a call operator, we must define conversion operators to + // particular action types. 
This is necessary for embedded actions like + // DoDefault(), which rely on an action conversion operators rather than + // providing a call operator because even with a particular set of arguments + // they don't have a fixed return type. + + template >::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + return std::move(final_action_); + } + + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>::value, + int>::type = 0> + operator Action() const { // NOLINT + return final_action_; + } + + private: + FinalAction final_action_; +}; + +// Recursive case: support N actions by calling the initial action and then +// calling through to the base class containing N-1 actions. +template +class DoAllAction + : private DoAllAction { + private: + using Base = DoAllAction; + + // The type of reference that should be provided to an initial action for a + // mocked function parameter of type T. + // + // There are two quirks here: + // + // * Unlike most forwarding functions, we pass scalars through by value. + // This isn't strictly necessary because an lvalue reference would work + // fine too and be consistent with other non-reference types, but it's + // perhaps less surprising. + // + // For example if the mocked function has signature void(int), then it + // might seem surprising for the user's initial action to need to be + // convertible to Action. This is perhaps less + // surprising for a non-scalar type where there may be a performance + // impact, or it might even be impossible, to pass by value. + // + // * More surprisingly, `const T&` is often not a const reference type. + // By the reference collapsing rules in C++17 [dcl.ref]/6, if T refers to + // U& or U&& for some non-scalar type U, then InitialActionArgType is + // U&. In other words, we may hand over a non-const reference. 
+ // + // So for example, given some non-scalar type Obj we have the following + // mappings: + // + // T InitialActionArgType + // ------- ----------------------- + // Obj const Obj& + // Obj& Obj& + // Obj&& Obj& + // const Obj const Obj& + // const Obj& const Obj& + // const Obj&& const Obj& + // + // In other words, the initial actions get a mutable view of an non-scalar + // argument if and only if the mock function itself accepts a non-const + // reference type. They are never given an rvalue reference to an + // non-scalar type. + // + // This situation makes sense if you imagine use with a matcher that is + // designed to write through a reference. For example, if the caller wants + // to fill in a reference argument and then return a canned value: + // + // EXPECT_CALL(mock, Call) + // .WillOnce(DoAll(SetArgReferee<0>(17), Return(19))); + // + template + using InitialActionArgType = + typename std::conditional::value, T, const T&>::type; + + public: + struct UserConstructorTag {}; + + template + explicit DoAllAction(UserConstructorTag, T&& initial_action, + U&&... other_actions) + : Base({}, std::forward(other_actions)...), + initial_action_(std::forward(initial_action)) {} + + template ...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + // Return an action that first calls the initial action with arguments + // filtered through InitialActionArgType, then forwards arguments directly + // to the base class to deal with the remaining actions. + struct OA { + OnceAction...)> initial_action; + OnceAction remaining_actions; + + R operator()(Args... args) && { + std::move(initial_action) + .Call(static_cast>(args)...); + + return std::move(remaining_actions).Call(std::forward(args)...); + } + }; + + return OA{ + std::move(initial_action_), + std::move(static_cast(*this)), + }; + } + + template < + typename R, typename... 
Args, + typename std::enable_if< + conjunction< + // Both the initial action and the rest must support conversion to + // Action. + std::is_convertible...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator Action() const { // NOLINT + // Return an action that first calls the initial action with arguments + // filtered through InitialActionArgType, then forwards arguments directly + // to the base class to deal with the remaining actions. + struct OA { + Action...)> initial_action; + Action remaining_actions; + + R operator()(Args... args) const { + initial_action.Perform(std::forward_as_tuple( + static_cast>(args)...)); + + return remaining_actions.Perform( + std::forward_as_tuple(std::forward(args)...)); + } + }; + + return OA{ + initial_action_, + static_cast(*this), + }; + } + + private: + InitialAction initial_action_; +}; + +template +struct ReturnNewAction { + T* operator()() const { + return internal::Apply( + [](const Params&... unpacked_params) { + return new T(unpacked_params...); + }, + params); + } + std::tuple params; +}; + +template +struct ReturnArgAction { + template ::type> + auto operator()(Args&&... args) const -> decltype(std::get( + std::forward_as_tuple(std::forward(args)...))) { + return std::get(std::forward_as_tuple(std::forward(args)...)); + } +}; + +template +struct SaveArgAction { + Ptr pointer; + + template + void operator()(const Args&... args) const { + *pointer = std::get(std::tie(args...)); + } +}; + +template +struct SaveArgPointeeAction { + Ptr pointer; + + template + void operator()(const Args&... args) const { + *pointer = *std::get(std::tie(args...)); + } +}; + +template +struct SetArgRefereeAction { + T value; + + template + void operator()(Args&&... 
args) const { + using argk_type = + typename ::std::tuple_element>::type; + static_assert(std::is_lvalue_reference::value, + "Argument must be a reference type."); + std::get(std::tie(args...)) = value; + } +}; + +template +struct SetArrayArgumentAction { + I1 first; + I2 last; + + template + void operator()(const Args&... args) const { + auto value = std::get(std::tie(args...)); + for (auto it = first; it != last; ++it, (void)++value) { + *value = *it; + } + } +}; + +template +struct DeleteArgAction { + template + void operator()(const Args&... args) const { + delete std::get(std::tie(args...)); + } +}; + +template +struct ReturnPointeeAction { + Ptr pointer; + template + auto operator()(const Args&...) const -> decltype(*pointer) { + return *pointer; + } +}; + +#if GTEST_HAS_EXCEPTIONS +template +struct ThrowAction { + T exception; + // We use a conversion operator to adapt to any return type. + template + operator Action() const { // NOLINT + T copy = exception; + return [copy](Args...) -> R { throw copy; }; + } +}; +#endif // GTEST_HAS_EXCEPTIONS + +} // namespace internal + +// An Unused object can be implicitly constructed from ANY value. +// This is handy when defining actions that ignore some or all of the +// mock function arguments. For example, given +// +// MOCK_METHOD3(Foo, double(const string& label, double x, double y)); +// MOCK_METHOD3(Bar, double(int index, double x, double y)); +// +// instead of +// +// double DistanceToOriginWithLabel(const string& label, double x, double y) { +// return sqrt(x*x + y*y); +// } +// double DistanceToOriginWithIndex(int index, double x, double y) { +// return sqrt(x*x + y*y); +// } +// ... +// EXPECT_CALL(mock, Foo("abc", _, _)) +// .WillOnce(Invoke(DistanceToOriginWithLabel)); +// EXPECT_CALL(mock, Bar(5, _, _)) +// .WillOnce(Invoke(DistanceToOriginWithIndex)); +// +// you could write +// +// // We can declare any uninteresting argument as Unused. 
+// double DistanceToOrigin(Unused, double x, double y) { +// return sqrt(x*x + y*y); +// } +// ... +// EXPECT_CALL(mock, Foo("abc", _, _)).WillOnce(Invoke(DistanceToOrigin)); +// EXPECT_CALL(mock, Bar(5, _, _)).WillOnce(Invoke(DistanceToOrigin)); +typedef internal::IgnoredValue Unused; + +// Creates an action that does actions a1, a2, ..., sequentially in +// each invocation. All but the last action will have a readonly view of the +// arguments. +template +internal::DoAllAction::type...> DoAll( + Action&&... action) { + return internal::DoAllAction::type...>( + {}, std::forward(action)...); +} + +// WithArg(an_action) creates an action that passes the k-th +// (0-based) argument of the mock function to an_action and performs +// it. It adapts an action accepting one argument to one that accepts +// multiple arguments. For convenience, we also provide +// WithArgs(an_action) (defined below) as a synonym. +template +internal::WithArgsAction::type, k> WithArg( + InnerAction&& action) { + return {std::forward(action)}; +} + +// WithArgs(an_action) creates an action that passes +// the selected arguments of the mock function to an_action and +// performs it. It serves as an adaptor between actions with +// different argument lists. +template +internal::WithArgsAction::type, k, ks...> +WithArgs(InnerAction&& action) { + return {std::forward(action)}; +} + +// WithoutArgs(inner_action) can be used in a mock function with a +// non-empty argument list to perform inner_action, which takes no +// argument. In other words, it adapts an action accepting no +// argument to one that accepts (and ignores) arguments. +template +internal::WithArgsAction::type> WithoutArgs( + InnerAction&& action) { + return {std::forward(action)}; +} + +// Creates an action that returns a value. 
+// +// The returned type can be used with a mock function returning a non-void, +// non-reference type U as follows: +// +// * If R is convertible to U and U is move-constructible, then the action can +// be used with WillOnce. +// +// * If const R& is convertible to U and U is copy-constructible, then the +// action can be used with both WillOnce and WillRepeatedly. +// +// The mock expectation contains the R value from which the U return value is +// constructed (a move/copy of the argument to Return). This means that the R +// value will survive at least until the mock object's expectations are cleared +// or the mock object is destroyed, meaning that U can safely be a +// reference-like type such as std::string_view: +// +// // The mock function returns a view of a copy of the string fed to +// // Return. The view is valid even after the action is performed. +// MockFunction mock; +// EXPECT_CALL(mock, Call).WillOnce(Return(std::string("taco"))); +// const std::string_view result = mock.AsStdFunction()(); +// EXPECT_EQ("taco", result); +// +template +internal::ReturnAction Return(R value) { + return internal::ReturnAction(std::move(value)); +} + +// Creates an action that returns NULL. +inline PolymorphicAction ReturnNull() { + return MakePolymorphicAction(internal::ReturnNullAction()); +} + +// Creates an action that returns from a void function. +inline PolymorphicAction Return() { + return MakePolymorphicAction(internal::ReturnVoidAction()); +} + +// Creates an action that returns the reference to a variable. +template +inline internal::ReturnRefAction ReturnRef(R& x) { // NOLINT + return internal::ReturnRefAction(x); +} + +// Prevent using ReturnRef on reference to temporary. +template +internal::ReturnRefAction ReturnRef(R&&) = delete; + +// Creates an action that returns the reference to a copy of the +// argument. The copy is created when the action is constructed and +// lives as long as the action. 
+template +inline internal::ReturnRefOfCopyAction ReturnRefOfCopy(const R& x) { + return internal::ReturnRefOfCopyAction(x); +} + +// DEPRECATED: use Return(x) directly with WillOnce. +// +// Modifies the parent action (a Return() action) to perform a move of the +// argument instead of a copy. +// Return(ByMove()) actions can only be executed once and will assert this +// invariant. +template +internal::ByMoveWrapper ByMove(R x) { + return internal::ByMoveWrapper(std::move(x)); +} + +// Creates an action that returns an element of `vals`. Calling this action will +// repeatedly return the next value from `vals` until it reaches the end and +// will restart from the beginning. +template +internal::ReturnRoundRobinAction ReturnRoundRobin(std::vector vals) { + return internal::ReturnRoundRobinAction(std::move(vals)); +} + +// Creates an action that returns an element of `vals`. Calling this action will +// repeatedly return the next value from `vals` until it reaches the end and +// will restart from the beginning. +template +internal::ReturnRoundRobinAction ReturnRoundRobin( + std::initializer_list vals) { + return internal::ReturnRoundRobinAction(std::vector(vals)); +} + +// Creates an action that does the default action for the give mock function. +inline internal::DoDefaultAction DoDefault() { + return internal::DoDefaultAction(); +} + +// Creates an action that sets the variable pointed by the N-th +// (0-based) function argument to 'value'. +template +internal::SetArgumentPointeeAction SetArgPointee(T value) { + return {std::move(value)}; +} + +// The following version is DEPRECATED. +template +internal::SetArgumentPointeeAction SetArgumentPointee(T value) { + return {std::move(value)}; +} + +// Creates an action that sets a pointer referent to a given value. 
+template +PolymorphicAction> Assign(T1* ptr, T2 val) { + return MakePolymorphicAction(internal::AssignAction(ptr, val)); +} + +#ifndef GTEST_OS_WINDOWS_MOBILE + +// Creates an action that sets errno and returns the appropriate error. +template +PolymorphicAction> SetErrnoAndReturn( + int errval, T result) { + return MakePolymorphicAction( + internal::SetErrnoAndReturnAction(errval, result)); +} + +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Various overloads for Invoke(). + +// Legacy function. +// Actions can now be implicitly constructed from callables. No need to create +// wrapper objects. +// This function exists for backwards compatibility. +template +typename std::decay::type Invoke(FunctionImpl&& function_impl) { + return std::forward(function_impl); +} + +// Creates an action that invokes the given method on the given object +// with the mock function's arguments. +template +internal::InvokeMethodAction Invoke(Class* obj_ptr, + MethodPtr method_ptr) { + return {obj_ptr, method_ptr}; +} + +// Creates an action that invokes 'function_impl' with no argument. +template +internal::InvokeWithoutArgsAction::type> +InvokeWithoutArgs(FunctionImpl function_impl) { + return {std::move(function_impl)}; +} + +// Creates an action that invokes the given method on the given object +// with no argument. +template +internal::InvokeMethodWithoutArgsAction InvokeWithoutArgs( + Class* obj_ptr, MethodPtr method_ptr) { + return {obj_ptr, method_ptr}; +} + +// Creates an action that performs an_action and throws away its +// result. In other words, it changes the return type of an_action to +// void. an_action MUST NOT return void, or the code won't compile. +template +inline internal::IgnoreResultAction IgnoreResult(const A& an_action) { + return internal::IgnoreResultAction(an_action); +} + +// Creates a reference wrapper for the given L-value. If necessary, +// you can explicitly specify the type of the reference. 
For example, +// suppose 'derived' is an object of type Derived, ByRef(derived) +// would wrap a Derived&. If you want to wrap a const Base& instead, +// where Base is a base class of Derived, just write: +// +// ByRef(derived) +// +// N.B. ByRef is redundant with std::ref, std::cref and std::reference_wrapper. +// However, it may still be used for consistency with ByMove(). +template +inline ::std::reference_wrapper ByRef(T& l_value) { // NOLINT + return ::std::reference_wrapper(l_value); +} + +// The ReturnNew(a1, a2, ..., a_k) action returns a pointer to a new +// instance of type T, constructed on the heap with constructor arguments +// a1, a2, ..., and a_k. The caller assumes ownership of the returned value. +template +internal::ReturnNewAction::type...> ReturnNew( + Params&&... params) { + return {std::forward_as_tuple(std::forward(params)...)}; +} + +// Action ReturnArg() returns the k-th argument of the mock function. +template +internal::ReturnArgAction ReturnArg() { + return {}; +} + +// Action SaveArg(pointer) saves the k-th (0-based) argument of the +// mock function to *pointer. +template +internal::SaveArgAction SaveArg(Ptr pointer) { + return {pointer}; +} + +// Action SaveArgPointee(pointer) saves the value pointed to +// by the k-th (0-based) argument of the mock function to *pointer. +template +internal::SaveArgPointeeAction SaveArgPointee(Ptr pointer) { + return {pointer}; +} + +// Action SetArgReferee(value) assigns 'value' to the variable +// referenced by the k-th (0-based) argument of the mock function. +template +internal::SetArgRefereeAction::type> SetArgReferee( + T&& value) { + return {std::forward(value)}; +} + +// Action SetArrayArgument(first, last) copies the elements in +// source range [first, last) to the array pointed to by the k-th +// (0-based) argument, which can be either a pointer or an +// iterator. The action does not take ownership of the elements in the +// source range. 
+template +internal::SetArrayArgumentAction SetArrayArgument(I1 first, + I2 last) { + return {first, last}; +} + +// Action DeleteArg() deletes the k-th (0-based) argument of the mock +// function. +template +internal::DeleteArgAction DeleteArg() { + return {}; +} + +// This action returns the value pointed to by 'pointer'. +template +internal::ReturnPointeeAction ReturnPointee(Ptr pointer) { + return {pointer}; +} + +// Action Throw(exception) can be used in a mock function of any type +// to throw the given exception. Any copyable value can be thrown. +#if GTEST_HAS_EXCEPTIONS +template +internal::ThrowAction::type> Throw(T&& exception) { + return {std::forward(exception)}; +} +#endif // GTEST_HAS_EXCEPTIONS + +namespace internal { + +// A macro from the ACTION* family (defined later in gmock-generated-actions.h) +// defines an action that can be used in a mock function. Typically, +// these actions only care about a subset of the arguments of the mock +// function. For example, if such an action only uses the second +// argument, it can be used in any mock function that takes >= 2 +// arguments where the type of the second argument is compatible. +// +// Therefore, the action implementation must be prepared to take more +// arguments than it needs. The ExcessiveArg type is used to +// represent those excessive arguments. In order to keep the compiler +// error messages tractable, we define it in the testing namespace +// instead of testing::internal. However, this is an INTERNAL TYPE +// and subject to change without notice, so a user MUST NOT USE THIS +// TYPE DIRECTLY. +struct ExcessiveArg {}; + +// Builds an implementation of an Action<> for some particular signature, using +// a class defined by an ACTION* macro. +template +struct ActionImpl; + +template +struct ImplBase { + struct Holder { + // Allows each copy of the Action<> to get to the Impl. 
+ explicit operator const Impl&() const { return *ptr; } + std::shared_ptr ptr; + }; + using type = typename std::conditional::value, + Impl, Holder>::type; +}; + +template +struct ActionImpl : ImplBase::type { + using Base = typename ImplBase::type; + using function_type = R(Args...); + using args_type = std::tuple; + + ActionImpl() = default; // Only defined if appropriate for Base. + explicit ActionImpl(std::shared_ptr impl) : Base{std::move(impl)} {} + + R operator()(Args&&... arg) const { + static constexpr size_t kMaxArgs = + sizeof...(Args) <= 10 ? sizeof...(Args) : 10; + return Apply(MakeIndexSequence{}, + MakeIndexSequence<10 - kMaxArgs>{}, + args_type{std::forward(arg)...}); + } + + template + R Apply(IndexSequence, IndexSequence, + const args_type& args) const { + // Impl need not be specific to the signature of action being implemented; + // only the implementing function body needs to have all of the specific + // types instantiated. Up to 10 of the args that are provided by the + // args_type get passed, followed by a dummy of unspecified type for the + // remainder up to 10 explicit args. + static constexpr ExcessiveArg kExcessArg{}; + return static_cast(*this) + .template gmock_PerformImpl< + /*function_type=*/function_type, /*return_type=*/R, + /*args_type=*/args_type, + /*argN_type=*/ + typename std::tuple_element::type...>( + /*args=*/args, std::get(args)..., + ((void)excess_id, kExcessArg)...); + } +}; + +// Stores a default-constructed Impl as part of the Action<>'s +// std::function<>. The Impl should be trivial to copy. +template +::testing::Action MakeAction() { + return ::testing::Action(ActionImpl()); +} + +// Stores just the one given instance of Impl. 
+template +::testing::Action MakeAction(std::shared_ptr impl) { + return ::testing::Action(ActionImpl(std::move(impl))); +} + +#define GMOCK_INTERNAL_ARG_UNUSED(i, data, el) \ + , const arg##i##_type& arg##i GTEST_ATTRIBUTE_UNUSED_ +#define GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_ \ + const args_type& args GTEST_ATTRIBUTE_UNUSED_ GMOCK_PP_REPEAT( \ + GMOCK_INTERNAL_ARG_UNUSED, , 10) + +#define GMOCK_INTERNAL_ARG(i, data, el) , const arg##i##_type& arg##i +#define GMOCK_ACTION_ARG_TYPES_AND_NAMES_ \ + const args_type& args GMOCK_PP_REPEAT(GMOCK_INTERNAL_ARG, , 10) + +#define GMOCK_INTERNAL_TEMPLATE_ARG(i, data, el) , typename arg##i##_type +#define GMOCK_ACTION_TEMPLATE_ARGS_NAMES_ \ + GMOCK_PP_TAIL(GMOCK_PP_REPEAT(GMOCK_INTERNAL_TEMPLATE_ARG, , 10)) + +#define GMOCK_INTERNAL_TYPENAME_PARAM(i, data, param) , typename param##_type +#define GMOCK_ACTION_TYPENAME_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPENAME_PARAM, , params)) + +#define GMOCK_INTERNAL_TYPE_PARAM(i, data, param) , param##_type +#define GMOCK_ACTION_TYPE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPE_PARAM, , params)) + +#define GMOCK_INTERNAL_TYPE_GVALUE_PARAM(i, data, param) \ + , param##_type gmock_p##i +#define GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPE_GVALUE_PARAM, , params)) + +#define GMOCK_INTERNAL_GVALUE_PARAM(i, data, param) \ + , std::forward(gmock_p##i) +#define GMOCK_ACTION_GVALUE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_GVALUE_PARAM, , params)) + +#define GMOCK_INTERNAL_INIT_PARAM(i, data, param) \ + , param(::std::forward(gmock_p##i)) +#define GMOCK_ACTION_INIT_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_INIT_PARAM, , params)) + +#define GMOCK_INTERNAL_FIELD_PARAM(i, data, param) param##_type param; +#define GMOCK_ACTION_FIELD_PARAMS_(params) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_FIELD_PARAM, , params) + +#define 
GMOCK_INTERNAL_ACTION(name, full_name, params) \ + template \ + class full_name { \ + public: \ + explicit full_name(GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) \ + : impl_(std::make_shared( \ + GMOCK_ACTION_GVALUE_PARAMS_(params))) {} \ + full_name(const full_name&) = default; \ + full_name(full_name&&) noexcept = default; \ + template \ + operator ::testing::Action() const { \ + return ::testing::internal::MakeAction(impl_); \ + } \ + \ + private: \ + class gmock_Impl { \ + public: \ + explicit gmock_Impl(GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) \ + : GMOCK_ACTION_INIT_PARAMS_(params) {} \ + template \ + return_type gmock_PerformImpl(GMOCK_ACTION_ARG_TYPES_AND_NAMES_) const; \ + GMOCK_ACTION_FIELD_PARAMS_(params) \ + }; \ + std::shared_ptr impl_; \ + }; \ + template \ + inline full_name name( \ + GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) GTEST_MUST_USE_RESULT_; \ + template \ + inline full_name name( \ + GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) { \ + return full_name( \ + GMOCK_ACTION_GVALUE_PARAMS_(params)); \ + } \ + template \ + template \ + return_type \ + full_name::gmock_Impl::gmock_PerformImpl( \ + GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_) const + +} // namespace internal + +// Similar to GMOCK_INTERNAL_ACTION, but no bound parameters are stored. +#define ACTION(name) \ + class name##Action { \ + public: \ + explicit name##Action() noexcept {} \ + name##Action(const name##Action&) noexcept {} \ + template \ + operator ::testing::Action() const { \ + return ::testing::internal::MakeAction(); \ + } \ + \ + private: \ + class gmock_Impl { \ + public: \ + template \ + return_type gmock_PerformImpl(GMOCK_ACTION_ARG_TYPES_AND_NAMES_) const; \ + }; \ + }; \ + inline name##Action name() GTEST_MUST_USE_RESULT_; \ + inline name##Action name() { return name##Action(); } \ + template \ + return_type name##Action::gmock_Impl::gmock_PerformImpl( \ + GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_) const + +#define ACTION_P(name, ...) 
\ + GMOCK_INTERNAL_ACTION(name, name##ActionP, (__VA_ARGS__)) + +#define ACTION_P2(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP2, (__VA_ARGS__)) + +#define ACTION_P3(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP3, (__VA_ARGS__)) + +#define ACTION_P4(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP4, (__VA_ARGS__)) + +#define ACTION_P5(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP5, (__VA_ARGS__)) + +#define ACTION_P6(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP6, (__VA_ARGS__)) + +#define ACTION_P7(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP7, (__VA_ARGS__)) + +#define ACTION_P8(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP8, (__VA_ARGS__)) + +#define ACTION_P9(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP9, (__VA_ARGS__)) + +#define ACTION_P10(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP10, (__VA_ARGS__)) + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4100 + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ diff --git a/third_party/googletest/googlemock/include/gmock/gmock-cardinalities.h b/third_party/googletest/googlemock/include/gmock/gmock-cardinalities.h new file mode 100644 index 0000000..533e604 --- /dev/null +++ b/third_party/googletest/googlemock/include/gmock/gmock-cardinalities.h @@ -0,0 +1,159 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. 
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// This file implements some commonly used cardinalities. More +// cardinalities can be defined by the user implementing the +// CardinalityInterface interface if necessary. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ + +#include + +#include +#include // NOLINT + +#include "gmock/internal/gmock-port.h" +#include "gtest/gtest.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +namespace testing { + +// To implement a cardinality Foo, define: +// 1. a class FooCardinality that implements the +// CardinalityInterface interface, and +// 2. a factory function that creates a Cardinality object from a +// const FooCardinality*. 
+// +// The two-level delegation design follows that of Matcher, providing +// consistency for extension developers. It also eases ownership +// management as Cardinality objects can now be copied like plain values. + +// The implementation of a cardinality. +class CardinalityInterface { + public: + virtual ~CardinalityInterface() = default; + + // Conservative estimate on the lower/upper bound of the number of + // calls allowed. + virtual int ConservativeLowerBound() const { return 0; } + virtual int ConservativeUpperBound() const { return INT_MAX; } + + // Returns true if and only if call_count calls will satisfy this + // cardinality. + virtual bool IsSatisfiedByCallCount(int call_count) const = 0; + + // Returns true if and only if call_count calls will saturate this + // cardinality. + virtual bool IsSaturatedByCallCount(int call_count) const = 0; + + // Describes self to an ostream. + virtual void DescribeTo(::std::ostream* os) const = 0; +}; + +// A Cardinality is a copyable and IMMUTABLE (except by assignment) +// object that specifies how many times a mock function is expected to +// be called. The implementation of Cardinality is just a std::shared_ptr +// to const CardinalityInterface. Don't inherit from Cardinality! +class GTEST_API_ Cardinality { + public: + // Constructs a null cardinality. Needed for storing Cardinality + // objects in STL containers. + Cardinality() = default; + + // Constructs a Cardinality from its implementation. + explicit Cardinality(const CardinalityInterface* impl) : impl_(impl) {} + + // Conservative estimate on the lower/upper bound of the number of + // calls allowed. + int ConservativeLowerBound() const { return impl_->ConservativeLowerBound(); } + int ConservativeUpperBound() const { return impl_->ConservativeUpperBound(); } + + // Returns true if and only if call_count calls will satisfy this + // cardinality. 
+ bool IsSatisfiedByCallCount(int call_count) const { + return impl_->IsSatisfiedByCallCount(call_count); + } + + // Returns true if and only if call_count calls will saturate this + // cardinality. + bool IsSaturatedByCallCount(int call_count) const { + return impl_->IsSaturatedByCallCount(call_count); + } + + // Returns true if and only if call_count calls will over-saturate this + // cardinality, i.e. exceed the maximum number of allowed calls. + bool IsOverSaturatedByCallCount(int call_count) const { + return impl_->IsSaturatedByCallCount(call_count) && + !impl_->IsSatisfiedByCallCount(call_count); + } + + // Describes self to an ostream + void DescribeTo(::std::ostream* os) const { impl_->DescribeTo(os); } + + // Describes the given actual call count to an ostream. + static void DescribeActualCallCountTo(int actual_call_count, + ::std::ostream* os); + + private: + std::shared_ptr impl_; +}; + +// Creates a cardinality that allows at least n calls. +GTEST_API_ Cardinality AtLeast(int n); + +// Creates a cardinality that allows at most n calls. +GTEST_API_ Cardinality AtMost(int n); + +// Creates a cardinality that allows any number of calls. +GTEST_API_ Cardinality AnyNumber(); + +// Creates a cardinality that allows between min and max calls. +GTEST_API_ Cardinality Between(int min, int max); + +// Creates a cardinality that allows exactly n calls. +GTEST_API_ Cardinality Exactly(int n); + +// Creates a cardinality from its implementation. 
+inline Cardinality MakeCardinality(const CardinalityInterface* c) { + return Cardinality(c); +} + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ diff --git a/third_party/googletest/googlemock/include/gmock/gmock-function-mocker.h b/third_party/googletest/googlemock/include/gmock/gmock-function-mocker.h new file mode 100644 index 0000000..1a1f126 --- /dev/null +++ b/third_party/googletest/googlemock/include/gmock/gmock-function-mocker.h @@ -0,0 +1,518 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// This file implements MOCK_METHOD. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ + +#include // IWYU pragma: keep +#include // IWYU pragma: keep + +#include "gmock/gmock-spec-builders.h" +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-pp.h" + +namespace testing { +namespace internal { +template +using identity_t = T; + +template +struct ThisRefAdjuster { + template + using AdjustT = typename std::conditional< + std::is_const::type>::value, + typename std::conditional::value, + const T&, const T&&>::type, + typename std::conditional::value, T&, + T&&>::type>::type; + + template + static AdjustT Adjust(const MockType& mock) { + return static_cast>(const_cast(mock)); + } +}; + +constexpr bool PrefixOf(const char* a, const char* b) { + return *a == 0 || (*a == *b && internal::PrefixOf(a + 1, b + 1)); +} + +template +constexpr bool StartsWith(const char (&prefix)[N], const char (&str)[M]) { + return N <= M && internal::PrefixOf(prefix, str); +} + +template +constexpr bool EndsWith(const char (&suffix)[N], const char (&str)[M]) { + return N <= M && internal::PrefixOf(suffix, str + M - N); +} + +template +constexpr bool Equals(const char (&a)[N], const char (&b)[M]) { + return N == M 
&& internal::PrefixOf(a, b); +} + +template +constexpr bool ValidateSpec(const char (&spec)[N]) { + return internal::Equals("const", spec) || + internal::Equals("override", spec) || + internal::Equals("final", spec) || + internal::Equals("noexcept", spec) || + (internal::StartsWith("noexcept(", spec) && + internal::EndsWith(")", spec)) || + internal::Equals("ref(&)", spec) || + internal::Equals("ref(&&)", spec) || + (internal::StartsWith("Calltype(", spec) && + internal::EndsWith(")", spec)); +} + +} // namespace internal + +// The style guide prohibits "using" statements in a namespace scope +// inside a header file. However, the FunctionMocker class template +// is meant to be defined in the ::testing namespace. The following +// line is just a trick for working around a bug in MSVC 8.0, which +// cannot handle it if we define FunctionMocker in ::testing. +using internal::FunctionMocker; +} // namespace testing + +#define MOCK_METHOD(...) \ + GMOCK_INTERNAL_WARNING_PUSH() \ + GMOCK_INTERNAL_WARNING_CLANG(ignored, "-Wunused-member-function") \ + GMOCK_PP_VARIADIC_CALL(GMOCK_INTERNAL_MOCK_METHOD_ARG_, __VA_ARGS__) \ + GMOCK_INTERNAL_WARNING_POP() + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_1(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_2(...) 
\ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_3(_Ret, _MethodName, _Args) \ + GMOCK_INTERNAL_MOCK_METHOD_ARG_4(_Ret, _MethodName, _Args, ()) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_4(_Ret, _MethodName, _Args, _Spec) \ + GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Args); \ + GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Spec); \ + GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE( \ + GMOCK_PP_NARG0 _Args, GMOCK_INTERNAL_SIGNATURE(_Ret, _Args)); \ + GMOCK_INTERNAL_ASSERT_VALID_SPEC(_Spec) \ + GMOCK_INTERNAL_MOCK_METHOD_IMPL( \ + GMOCK_PP_NARG0 _Args, _MethodName, GMOCK_INTERNAL_HAS_CONST(_Spec), \ + GMOCK_INTERNAL_HAS_OVERRIDE(_Spec), GMOCK_INTERNAL_HAS_FINAL(_Spec), \ + GMOCK_INTERNAL_GET_NOEXCEPT_SPEC(_Spec), \ + GMOCK_INTERNAL_GET_CALLTYPE_SPEC(_Spec), \ + GMOCK_INTERNAL_GET_REF_SPEC(_Spec), \ + (GMOCK_INTERNAL_SIGNATURE(_Ret, _Args))) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_5(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_6(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_7(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_WRONG_ARITY(...) \ + static_assert( \ + false, \ + "MOCK_METHOD must be called with 3 or 4 arguments. _Ret, " \ + "_MethodName, _Args and optionally _Spec. _Args and _Spec must be " \ + "enclosed in parentheses. If _Ret is a type with unprotected commas, " \ + "it must also be enclosed in parentheses.") + +#define GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Tuple) \ + static_assert( \ + GMOCK_PP_IS_ENCLOSED_PARENS(_Tuple), \ + GMOCK_PP_STRINGIZE(_Tuple) " should be enclosed in parentheses.") + +#define GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE(_N, ...) 
\ + static_assert( \ + std::is_function<__VA_ARGS__>::value, \ + "Signature must be a function type, maybe return type contains " \ + "unprotected comma."); \ + static_assert( \ + ::testing::tuple_size::ArgumentTuple>::value == _N, \ + "This method does not take " GMOCK_PP_STRINGIZE( \ + _N) " arguments. Parenthesize all types with unprotected commas.") + +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC(_Spec) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT, ~, _Spec) + +#define GMOCK_INTERNAL_MOCK_METHOD_IMPL(_N, _MethodName, _Constness, \ + _Override, _Final, _NoexceptSpec, \ + _CallType, _RefSpec, _Signature) \ + typename ::testing::internal::Function::Result \ + GMOCK_INTERNAL_EXPAND(_CallType) \ + _MethodName(GMOCK_PP_REPEAT(GMOCK_INTERNAL_PARAMETER, _Signature, _N)) \ + GMOCK_PP_IF(_Constness, const, ) \ + _RefSpec _NoexceptSpec GMOCK_PP_IF(_Override, override, ) \ + GMOCK_PP_IF(_Final, final, ) { \ + GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .SetOwnerAndName(this, #_MethodName); \ + return GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .Invoke(GMOCK_PP_REPEAT(GMOCK_INTERNAL_FORWARD_ARG, _Signature, _N)); \ + } \ + ::testing::MockSpec gmock_##_MethodName( \ + GMOCK_PP_REPEAT(GMOCK_INTERNAL_MATCHER_PARAMETER, _Signature, _N)) \ + GMOCK_PP_IF(_Constness, const, ) _RefSpec { \ + GMOCK_MOCKER_(_N, _Constness, _MethodName).RegisterOwner(this); \ + return GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .With(GMOCK_PP_REPEAT(GMOCK_INTERNAL_MATCHER_ARGUMENT, , _N)); \ + } \ + ::testing::MockSpec gmock_##_MethodName( \ + const ::testing::internal::WithoutMatchers&, \ + GMOCK_PP_IF(_Constness, const, )::testing::internal::Function< \ + GMOCK_PP_REMOVE_PARENS(_Signature)>*) const _RefSpec _NoexceptSpec { \ + return ::testing::internal::ThisRefAdjuster::Adjust(*this) \ + .gmock_##_MethodName(GMOCK_PP_REPEAT( \ + GMOCK_INTERNAL_A_MATCHER_ARGUMENT, _Signature, _N)); \ + } \ + mutable ::testing::FunctionMocker \ + GMOCK_MOCKER_(_N, _Constness, _MethodName) + 
+#define GMOCK_INTERNAL_EXPAND(...) __VA_ARGS__ + +// Valid modifiers. +#define GMOCK_INTERNAL_HAS_CONST(_Tuple) \ + GMOCK_PP_HAS_COMMA(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_CONST, ~, _Tuple)) + +#define GMOCK_INTERNAL_HAS_OVERRIDE(_Tuple) \ + GMOCK_PP_HAS_COMMA( \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_OVERRIDE, ~, _Tuple)) + +#define GMOCK_INTERNAL_HAS_FINAL(_Tuple) \ + GMOCK_PP_HAS_COMMA(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_FINAL, ~, _Tuple)) + +#define GMOCK_INTERNAL_GET_NOEXCEPT_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_NOEXCEPT_SPEC_IF_NOEXCEPT, ~, _Tuple) + +#define GMOCK_INTERNAL_NOEXCEPT_SPEC_IF_NOEXCEPT(_i, _, _elem) \ + GMOCK_PP_IF( \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem)), \ + _elem, ) + +#define GMOCK_INTERNAL_GET_CALLTYPE_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_CALLTYPE_SPEC_IF_CALLTYPE, ~, _Tuple) + +#define GMOCK_INTERNAL_CALLTYPE_SPEC_IF_CALLTYPE(_i, _, _elem) \ + GMOCK_PP_IF( \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem)), \ + GMOCK_PP_CAT(GMOCK_INTERNAL_UNPACK_, _elem), ) + +#define GMOCK_INTERNAL_GET_REF_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_REF_SPEC_IF_REF, ~, _Tuple) + +#define GMOCK_INTERNAL_REF_SPEC_IF_REF(_i, _, _elem) \ + GMOCK_PP_IF(GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_REF(_i, _, _elem)), \ + GMOCK_PP_CAT(GMOCK_INTERNAL_UNPACK_, _elem), ) + +#ifdef GMOCK_INTERNAL_STRICT_SPEC_ASSERT +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT(_i, _, _elem) \ + static_assert( \ + ::testing::internal::ValidateSpec(GMOCK_PP_STRINGIZE(_elem)), \ + "Token \'" GMOCK_PP_STRINGIZE( \ + _elem) "\' cannot be recognized as a valid specification " \ + "modifier. 
Is a ',' missing?"); +#else +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT(_i, _, _elem) \ + static_assert( \ + (GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CONST(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_OVERRIDE(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_FINAL(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_REF(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem))) == 1, \ + GMOCK_PP_STRINGIZE( \ + _elem) " cannot be recognized as a valid specification modifier."); +#endif // GMOCK_INTERNAL_STRICT_SPEC_ASSERT + +// Modifiers implementation. +#define GMOCK_INTERNAL_DETECT_CONST(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_CONST_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_CONST_I_const , + +#define GMOCK_INTERNAL_DETECT_OVERRIDE(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_OVERRIDE_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_OVERRIDE_I_override , + +#define GMOCK_INTERNAL_DETECT_FINAL(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_FINAL_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_FINAL_I_final , + +#define GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_NOEXCEPT_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_NOEXCEPT_I_noexcept , + +#define GMOCK_INTERNAL_DETECT_REF(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_REF_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_REF_I_ref , + +#define GMOCK_INTERNAL_UNPACK_ref(x) x + +#define GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_CALLTYPE_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_CALLTYPE_I_Calltype , + +#define GMOCK_INTERNAL_UNPACK_Calltype(...) __VA_ARGS__ + +// Note: The use of `identity_t` here allows _Ret to represent return types that +// would normally need to be specified in a different way. 
For example, a method +// returning a function pointer must be written as +// +// fn_ptr_return_t (*method(method_args_t...))(fn_ptr_args_t...) +// +// But we only support placing the return type at the beginning. To handle this, +// we wrap all calls in identity_t, so that a declaration will be expanded to +// +// identity_t method(method_args_t...) +// +// This allows us to work around the syntactic oddities of function/method +// types. +#define GMOCK_INTERNAL_SIGNATURE(_Ret, _Args) \ + ::testing::internal::identity_t( \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_GET_TYPE, _, _Args)) + +#define GMOCK_INTERNAL_GET_TYPE(_i, _, _elem) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_PP_IF(GMOCK_PP_IS_BEGIN_PARENS(_elem), GMOCK_PP_REMOVE_PARENS, \ + GMOCK_PP_IDENTITY) \ + (_elem) + +#define GMOCK_INTERNAL_PARAMETER(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_INTERNAL_ARG_O(_i, GMOCK_PP_REMOVE_PARENS(_Signature)) \ + gmock_a##_i + +#define GMOCK_INTERNAL_FORWARD_ARG(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + ::std::forward(gmock_a##_i) + +#define GMOCK_INTERNAL_MATCHER_PARAMETER(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_INTERNAL_MATCHER_O(_i, GMOCK_PP_REMOVE_PARENS(_Signature)) \ + gmock_a##_i + +#define GMOCK_INTERNAL_MATCHER_ARGUMENT(_i, _1, _2) \ + GMOCK_PP_COMMA_IF(_i) \ + gmock_a##_i + +#define GMOCK_INTERNAL_A_MATCHER_ARGUMENT(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + ::testing::A() + +#define GMOCK_INTERNAL_ARG_O(_i, ...) \ + typename ::testing::internal::Function<__VA_ARGS__>::template Arg<_i>::type + +#define GMOCK_INTERNAL_MATCHER_O(_i, ...) \ + const ::testing::Matcher::template Arg<_i>::type>& + +#define MOCK_METHOD0(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 0, __VA_ARGS__) +#define MOCK_METHOD1(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 1, __VA_ARGS__) +#define MOCK_METHOD2(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 2, __VA_ARGS__) +#define MOCK_METHOD3(m, ...) 
GMOCK_INTERNAL_MOCK_METHODN(, , m, 3, __VA_ARGS__) +#define MOCK_METHOD4(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 4, __VA_ARGS__) +#define MOCK_METHOD5(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 5, __VA_ARGS__) +#define MOCK_METHOD6(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 6, __VA_ARGS__) +#define MOCK_METHOD7(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 7, __VA_ARGS__) +#define MOCK_METHOD8(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 8, __VA_ARGS__) +#define MOCK_METHOD9(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 9, __VA_ARGS__) +#define MOCK_METHOD10(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, , m, 10, __VA_ARGS__) + +#define MOCK_CONST_METHOD0(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 0, __VA_ARGS__) +#define MOCK_CONST_METHOD1(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 1, __VA_ARGS__) +#define MOCK_CONST_METHOD2(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 2, __VA_ARGS__) +#define MOCK_CONST_METHOD3(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 3, __VA_ARGS__) +#define MOCK_CONST_METHOD4(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 4, __VA_ARGS__) +#define MOCK_CONST_METHOD5(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 5, __VA_ARGS__) +#define MOCK_CONST_METHOD6(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 6, __VA_ARGS__) +#define MOCK_CONST_METHOD7(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 7, __VA_ARGS__) +#define MOCK_CONST_METHOD8(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 8, __VA_ARGS__) +#define MOCK_CONST_METHOD9(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 9, __VA_ARGS__) +#define MOCK_CONST_METHOD10(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 10, __VA_ARGS__) + +#define MOCK_METHOD0_T(m, ...) MOCK_METHOD0(m, __VA_ARGS__) +#define MOCK_METHOD1_T(m, ...) MOCK_METHOD1(m, __VA_ARGS__) +#define MOCK_METHOD2_T(m, ...) MOCK_METHOD2(m, __VA_ARGS__) +#define MOCK_METHOD3_T(m, ...) MOCK_METHOD3(m, __VA_ARGS__) +#define MOCK_METHOD4_T(m, ...) 
MOCK_METHOD4(m, __VA_ARGS__) +#define MOCK_METHOD5_T(m, ...) MOCK_METHOD5(m, __VA_ARGS__) +#define MOCK_METHOD6_T(m, ...) MOCK_METHOD6(m, __VA_ARGS__) +#define MOCK_METHOD7_T(m, ...) MOCK_METHOD7(m, __VA_ARGS__) +#define MOCK_METHOD8_T(m, ...) MOCK_METHOD8(m, __VA_ARGS__) +#define MOCK_METHOD9_T(m, ...) MOCK_METHOD9(m, __VA_ARGS__) +#define MOCK_METHOD10_T(m, ...) MOCK_METHOD10(m, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_T(m, ...) MOCK_CONST_METHOD0(m, __VA_ARGS__) +#define MOCK_CONST_METHOD1_T(m, ...) MOCK_CONST_METHOD1(m, __VA_ARGS__) +#define MOCK_CONST_METHOD2_T(m, ...) MOCK_CONST_METHOD2(m, __VA_ARGS__) +#define MOCK_CONST_METHOD3_T(m, ...) MOCK_CONST_METHOD3(m, __VA_ARGS__) +#define MOCK_CONST_METHOD4_T(m, ...) MOCK_CONST_METHOD4(m, __VA_ARGS__) +#define MOCK_CONST_METHOD5_T(m, ...) MOCK_CONST_METHOD5(m, __VA_ARGS__) +#define MOCK_CONST_METHOD6_T(m, ...) MOCK_CONST_METHOD6(m, __VA_ARGS__) +#define MOCK_CONST_METHOD7_T(m, ...) MOCK_CONST_METHOD7(m, __VA_ARGS__) +#define MOCK_CONST_METHOD8_T(m, ...) MOCK_CONST_METHOD8(m, __VA_ARGS__) +#define MOCK_CONST_METHOD9_T(m, ...) MOCK_CONST_METHOD9(m, __VA_ARGS__) +#define MOCK_CONST_METHOD10_T(m, ...) MOCK_CONST_METHOD10(m, __VA_ARGS__) + +#define MOCK_METHOD0_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 0, __VA_ARGS__) +#define MOCK_METHOD1_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 1, __VA_ARGS__) +#define MOCK_METHOD2_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 2, __VA_ARGS__) +#define MOCK_METHOD3_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 3, __VA_ARGS__) +#define MOCK_METHOD4_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 4, __VA_ARGS__) +#define MOCK_METHOD5_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 5, __VA_ARGS__) +#define MOCK_METHOD6_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 6, __VA_ARGS__) +#define MOCK_METHOD7_WITH_CALLTYPE(ct, m, ...) 
\ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 7, __VA_ARGS__) +#define MOCK_METHOD8_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 8, __VA_ARGS__) +#define MOCK_METHOD9_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 9, __VA_ARGS__) +#define MOCK_METHOD10_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 10, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 0, __VA_ARGS__) +#define MOCK_CONST_METHOD1_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 1, __VA_ARGS__) +#define MOCK_CONST_METHOD2_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 2, __VA_ARGS__) +#define MOCK_CONST_METHOD3_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 3, __VA_ARGS__) +#define MOCK_CONST_METHOD4_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 4, __VA_ARGS__) +#define MOCK_CONST_METHOD5_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 5, __VA_ARGS__) +#define MOCK_CONST_METHOD6_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 6, __VA_ARGS__) +#define MOCK_CONST_METHOD7_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 7, __VA_ARGS__) +#define MOCK_CONST_METHOD8_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 8, __VA_ARGS__) +#define MOCK_CONST_METHOD9_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 9, __VA_ARGS__) +#define MOCK_CONST_METHOD10_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 10, __VA_ARGS__) + +#define MOCK_METHOD0_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD0_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD1_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD1_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD2_T_WITH_CALLTYPE(ct, m, ...) 
\ + MOCK_METHOD2_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD3_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD3_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD4_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD4_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD5_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD5_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD6_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD6_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD7_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD7_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD8_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD8_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD9_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD9_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD10_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD10_WITH_CALLTYPE(ct, m, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD0_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD1_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD1_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD2_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD2_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD3_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD3_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD4_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD4_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD5_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD5_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD6_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD6_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD7_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD7_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD8_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD8_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD9_T_WITH_CALLTYPE(ct, m, ...) 
\ + MOCK_CONST_METHOD9_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD10_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD10_WITH_CALLTYPE(ct, m, __VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHODN(constness, ct, Method, args_num, ...) \ + GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE( \ + args_num, ::testing::internal::identity_t<__VA_ARGS__>); \ + GMOCK_INTERNAL_MOCK_METHOD_IMPL( \ + args_num, Method, GMOCK_PP_NARG0(constness), 0, 0, , ct, , \ + (::testing::internal::identity_t<__VA_ARGS__>)) + +#define GMOCK_MOCKER_(arity, constness, Method) \ + GTEST_CONCAT_TOKEN_(gmock##constness##arity##_##Method##_, __LINE__) + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ diff --git a/third_party/googletest/googlemock/include/gmock/gmock-matchers.h b/third_party/googletest/googlemock/include/gmock/gmock-matchers.h new file mode 100644 index 0000000..0f67713 --- /dev/null +++ b/third_party/googletest/googlemock/include/gmock/gmock-matchers.h @@ -0,0 +1,5623 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// The MATCHER* family of macros can be used in a namespace scope to +// define custom matchers easily. +// +// Basic Usage +// =========== +// +// The syntax +// +// MATCHER(name, description_string) { statements; } +// +// defines a matcher with the given name that executes the statements, +// which must return a bool to indicate if the match succeeds. Inside +// the statements, you can refer to the value being matched by 'arg', +// and refer to its type by 'arg_type'. +// +// The description string documents what the matcher does, and is used +// to generate the failure message when the match fails. Since a +// MATCHER() is usually defined in a header file shared by multiple +// C++ source files, we require the description to be a C-string +// literal to avoid possible side effects. It can be empty, in which +// case we'll use the sequence of words in the matcher name as the +// description. +// +// For example: +// +// MATCHER(IsEven, "") { return (arg % 2) == 0; } +// +// allows you to write +// +// // Expects mock_foo.Bar(n) to be called where n is even. 
+// EXPECT_CALL(mock_foo, Bar(IsEven())); +// +// or, +// +// // Verifies that the value of some_expression is even. +// EXPECT_THAT(some_expression, IsEven()); +// +// If the above assertion fails, it will print something like: +// +// Value of: some_expression +// Expected: is even +// Actual: 7 +// +// where the description "is even" is automatically calculated from the +// matcher name IsEven. +// +// Argument Type +// ============= +// +// Note that the type of the value being matched (arg_type) is +// determined by the context in which you use the matcher and is +// supplied to you by the compiler, so you don't need to worry about +// declaring it (nor can you). This allows the matcher to be +// polymorphic. For example, IsEven() can be used to match any type +// where the value of "(arg % 2) == 0" can be implicitly converted to +// a bool. In the "Bar(IsEven())" example above, if method Bar() +// takes an int, 'arg_type' will be int; if it takes an unsigned long, +// 'arg_type' will be unsigned long; and so on. +// +// Parameterizing Matchers +// ======================= +// +// Sometimes you'll want to parameterize the matcher. For that you +// can use another macro: +// +// MATCHER_P(name, param_name, description_string) { statements; } +// +// For example: +// +// MATCHER_P(HasAbsoluteValue, value, "") { return abs(arg) == value; } +// +// will allow you to write: +// +// EXPECT_THAT(Blah("a"), HasAbsoluteValue(n)); +// +// which may lead to this message (assuming n is 10): +// +// Value of: Blah("a") +// Expected: has absolute value 10 +// Actual: -9 +// +// Note that both the matcher description and its parameter are +// printed, making the message human-friendly. +// +// In the matcher definition body, you can write 'foo_type' to +// reference the type of a parameter named 'foo'. For example, in the +// body of MATCHER_P(HasAbsoluteValue, value) above, you can write +// 'value_type' to refer to the type of 'value'. 
+// +// We also provide MATCHER_P2, MATCHER_P3, ..., up to MATCHER_P$n to +// support multi-parameter matchers. +// +// Describing Parameterized Matchers +// ================================= +// +// The last argument to MATCHER*() is a string-typed expression. The +// expression can reference all of the matcher's parameters and a +// special bool-typed variable named 'negation'. When 'negation' is +// false, the expression should evaluate to the matcher's description; +// otherwise it should evaluate to the description of the negation of +// the matcher. For example, +// +// using testing::PrintToString; +// +// MATCHER_P2(InClosedRange, low, hi, +// std::string(negation ? "is not" : "is") + " in range [" + +// PrintToString(low) + ", " + PrintToString(hi) + "]") { +// return low <= arg && arg <= hi; +// } +// ... +// EXPECT_THAT(3, InClosedRange(4, 6)); +// EXPECT_THAT(3, Not(InClosedRange(2, 4))); +// +// would generate two failures that contain the text: +// +// Expected: is in range [4, 6] +// ... +// Expected: is not in range [2, 4] +// +// If you specify "" as the description, the failure message will +// contain the sequence of words in the matcher name followed by the +// parameter values printed as a tuple. For example, +// +// MATCHER_P2(InClosedRange, low, hi, "") { ... } +// ... +// EXPECT_THAT(3, InClosedRange(4, 6)); +// EXPECT_THAT(3, Not(InClosedRange(2, 4))); +// +// would generate two failures that contain the text: +// +// Expected: in closed range (4, 6) +// ... +// Expected: not (in closed range (2, 4)) +// +// Types of Matcher Parameters +// =========================== +// +// For the purpose of typing, you can view +// +// MATCHER_Pk(Foo, p1, ..., pk, description_string) { ... } +// +// as shorthand for +// +// template +// FooMatcherPk +// Foo(p1_type p1, ..., pk_type pk) { ... } +// +// When you write Foo(v1, ..., vk), the compiler infers the types of +// the parameters v1, ..., and vk for you. 
If you are not happy with +// the result of the type inference, you can specify the types by +// explicitly instantiating the template, as in Foo(5, +// false). As said earlier, you don't get to (or need to) specify +// 'arg_type' as that's determined by the context in which the matcher +// is used. You can assign the result of expression Foo(p1, ..., pk) +// to a variable of type FooMatcherPk. This +// can be useful when composing matchers. +// +// While you can instantiate a matcher template with reference types, +// passing the parameters by pointer usually makes your code more +// readable. If, however, you still want to pass a parameter by +// reference, be aware that in the failure message generated by the +// matcher you will see the value of the referenced object but not its +// address. +// +// Explaining Match Results +// ======================== +// +// Sometimes the matcher description alone isn't enough to explain why +// the match has failed or succeeded. For example, when expecting a +// long string, it can be very helpful to also print the diff between +// the expected string and the actual one. To achieve that, you can +// optionally stream additional information to a special variable +// named result_listener, whose type is a pointer to class +// MatchResultListener: +// +// MATCHER_P(EqualsLongString, str, "") { +// if (arg == str) return true; +// +// *result_listener << "the difference: " +/// << DiffStrings(str, arg); +// return false; +// } +// +// Overloading Matchers +// ==================== +// +// You can overload matchers with different numbers of parameters: +// +// MATCHER_P(Blah, a, description_string1) { ... } +// MATCHER_P2(Blah, a, b, description_string2) { ... } +// +// Caveats +// ======= +// +// When defining a new matcher, you should also consider implementing +// MatcherInterface or using MakePolymorphicMatcher(). 
These +// approaches require more work than the MATCHER* macros, but also +// give you more control on the types of the value being matched and +// the matcher parameters, which may leads to better compiler error +// messages when the matcher is used wrong. They also allow +// overloading matchers based on parameter types (as opposed to just +// based on the number of parameters). +// +// MATCHER*() can only be used in a namespace scope as templates cannot be +// declared inside of a local class. +// +// More Information +// ================ +// +// To learn more about using these macros, please search for 'MATCHER' +// on +// https://github.com/google/googletest/blob/main/docs/gmock_cook_book.md +// +// This file also implements some commonly used argument matchers. More +// matchers can be defined by the user implementing the +// MatcherInterface interface if necessary. +// +// See googletest/include/gtest/gtest-matchers.h for the definition of class +// Matcher, class MatcherInterface, and others. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MATCHERS_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MATCHERS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include + +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-port.h" +#include "gmock/internal/gmock-pp.h" +#include "gtest/gtest.h" + +// MSVC warning C5046 is new as of VS2017 version 15.8. +#if defined(_MSC_VER) && _MSC_VER >= 1915 +#define GMOCK_MAYBE_5046_ 5046 +#else +#define GMOCK_MAYBE_5046_ +#endif + +GTEST_DISABLE_MSC_WARNINGS_PUSH_( + 4251 GMOCK_MAYBE_5046_ /* class A needs to have dll-interface to be used by + clients of class B */ + /* Symbol involving type with internal linkage not defined */) + +namespace testing { + +// To implement a matcher Foo for type T, define: +// 1. 
a class FooMatcherImpl that implements the +// MatcherInterface interface, and +// 2. a factory function that creates a Matcher object from a +// FooMatcherImpl*. +// +// The two-level delegation design makes it possible to allow a user +// to write "v" instead of "Eq(v)" where a Matcher is expected, which +// is impossible if we pass matchers by pointers. It also eases +// ownership management as Matcher objects can now be copied like +// plain values. + +// A match result listener that stores the explanation in a string. +class StringMatchResultListener : public MatchResultListener { + public: + StringMatchResultListener() : MatchResultListener(&ss_) {} + + // Returns the explanation accumulated so far. + std::string str() const { return ss_.str(); } + + // Clears the explanation accumulated so far. + void Clear() { ss_.str(""); } + + private: + ::std::stringstream ss_; + + StringMatchResultListener(const StringMatchResultListener&) = delete; + StringMatchResultListener& operator=(const StringMatchResultListener&) = + delete; +}; + +// Anything inside the 'internal' namespace IS INTERNAL IMPLEMENTATION +// and MUST NOT BE USED IN USER CODE!!! +namespace internal { + +// The MatcherCastImpl class template is a helper for implementing +// MatcherCast(). We need this helper in order to partially +// specialize the implementation of MatcherCast() (C++ allows +// class/struct templates to be partially specialized, but not +// function templates.). + +// This general version is used when MatcherCast()'s argument is a +// polymorphic matcher (i.e. something that can be converted to a +// Matcher but is not one yet; for example, Eq(value)) or a value (for +// example, "hello"). +template +class MatcherCastImpl { + public: + static Matcher Cast(const M& polymorphic_matcher_or_value) { + // M can be a polymorphic matcher, in which case we want to use + // its conversion operator to create Matcher. 
Or it can be a value + // that should be passed to the Matcher's constructor. + // + // We can't call Matcher(polymorphic_matcher_or_value) when M is a + // polymorphic matcher because it'll be ambiguous if T has an implicit + // constructor from M (this usually happens when T has an implicit + // constructor from any type). + // + // It won't work to unconditionally implicit_cast + // polymorphic_matcher_or_value to Matcher because it won't trigger + // a user-defined conversion from M to T if one exists (assuming M is + // a value). + return CastImpl(polymorphic_matcher_or_value, + std::is_convertible>{}, + std::is_convertible{}); + } + + private: + template + static Matcher CastImpl(const M& polymorphic_matcher_or_value, + std::true_type /* convertible_to_matcher */, + std::integral_constant) { + // M is implicitly convertible to Matcher, which means that either + // M is a polymorphic matcher or Matcher has an implicit constructor + // from M. In both cases using the implicit conversion will produce a + // matcher. + // + // Even if T has an implicit constructor from M, it won't be called because + // creating Matcher would require a chain of two user-defined conversions + // (first to create T from M and then to create Matcher from T). + return polymorphic_matcher_or_value; + } + + // M can't be implicitly converted to Matcher, so M isn't a polymorphic + // matcher. It's a value of a type implicitly convertible to T. Use direct + // initialization to create a matcher. + static Matcher CastImpl(const M& value, + std::false_type /* convertible_to_matcher */, + std::true_type /* convertible_to_T */) { + return Matcher(ImplicitCast_(value)); + } + + // M can't be implicitly converted to either Matcher or T. Attempt to use + // polymorphic matcher Eq(value) in this case. 
+ // + // Note that we first attempt to perform an implicit cast on the value and + // only fall back to the polymorphic Eq() matcher afterwards because the + // latter calls bool operator==(const Lhs& lhs, const Rhs& rhs) in the end + // which might be undefined even when Rhs is implicitly convertible to Lhs + // (e.g. std::pair vs. std::pair). + // + // We don't define this method inline as we need the declaration of Eq(). + static Matcher CastImpl(const M& value, + std::false_type /* convertible_to_matcher */, + std::false_type /* convertible_to_T */); +}; + +// This more specialized version is used when MatcherCast()'s argument +// is already a Matcher. This only compiles when type T can be +// statically converted to type U. +template +class MatcherCastImpl> { + public: + static Matcher Cast(const Matcher& source_matcher) { + return Matcher(new Impl(source_matcher)); + } + + private: + class Impl : public MatcherInterface { + public: + explicit Impl(const Matcher& source_matcher) + : source_matcher_(source_matcher) {} + + // We delegate the matching logic to the source matcher. + bool MatchAndExplain(T x, MatchResultListener* listener) const override { + using FromType = typename std::remove_cv::type>::type>::type; + using ToType = typename std::remove_cv::type>::type>::type; + // Do not allow implicitly converting base*/& to derived*/&. + static_assert( + // Do not trigger if only one of them is a pointer. That implies a + // regular conversion and not a down_cast. + (std::is_pointer::type>::value != + std::is_pointer::type>::value) || + std::is_same::value || + !std::is_base_of::value, + "Can't implicitly convert from to "); + + // Do the cast to `U` explicitly if necessary. + // Otherwise, let implicit conversions do the trick. 
+ using CastType = + typename std::conditional::value, + T&, U>::type; + + return source_matcher_.MatchAndExplain(static_cast(x), + listener); + } + + void DescribeTo(::std::ostream* os) const override { + source_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(::std::ostream* os) const override { + source_matcher_.DescribeNegationTo(os); + } + + private: + const Matcher source_matcher_; + }; +}; + +// This even more specialized version is used for efficiently casting +// a matcher to its own type. +template +class MatcherCastImpl> { + public: + static Matcher Cast(const Matcher& matcher) { return matcher; } +}; + +// Template specialization for parameterless Matcher. +template +class MatcherBaseImpl { + public: + MatcherBaseImpl() = default; + + template + operator ::testing::Matcher() const { // NOLINT(runtime/explicit) + return ::testing::Matcher(new + typename Derived::template gmock_Impl()); + } +}; + +// Template specialization for Matcher with parameters. +template