-
Notifications
You must be signed in to change notification settings - Fork 160
Expand file tree
/
Copy pathinclude_checker.py
More file actions
89 lines (75 loc) · 2.46 KB
/
include_checker.py
File metadata and controls
89 lines (75 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import print_function
import sys
import re
import os
import argparse
# Matches a preprocessor include directive and captures its operand, e.g.
# ``<vector>`` or ``"foo.h"``. Anchored at the start of the line because a
# directive's ``#`` must be the first non-blank token (so text such as
# ``s = "#include <x>"`` is not mistaken for a directive). Whitespace is
# allowed between ``#`` and ``include`` (``#  include`` inside nested #if's).
IncludeRegex = re.compile(r"^\s*#\s*include\s*(\S+)")
# Strips ``//`` line comments so commented-out includes are ignored.
# NOTE: block comments (``/* ... */``) are not handled.
RemoveComments = re.compile(r"//.*")
# Matches any path containing "thirdparty"; used to skip third-party
# directories when collecting sources.
exclusion_regex = re.compile(r".*thirdparty.*")
def parse_args(argv=None):
    """Parse command-line options for the include checker.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic use and testing).

    Returns:
        argparse.Namespace with ``regex`` (the file-name filter pattern),
        ``regex_compiled`` (that pattern compiled) and ``dirs`` (the list of
        directories to scan, possibly empty).
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in the usage line), not the description, so the help text
    # must be passed by keyword.
    argparser = argparse.ArgumentParser(
        description="Checks for a consistent '#include' syntax"
    )
    argparser.add_argument(
        "--regex",
        type=str,
        default=r"[.](cu|cuh|h|hpp|hxx|cpp)$",
        help="Regex string to filter in sources",
    )
    argparser.add_argument(
        "dirs", type=str, nargs="*", help="List of dirs where to find sources"
    )
    args = argparser.parse_args(argv)
    args.regex_compiled = re.compile(args.regex)
    return args
def list_all_source_file(file_regex, srcdirs, exclude_regex=None):
    """Collect source files under the given directories.

    Recursively walks every directory in ``srcdirs`` and returns
    ``os.path.join(root, name)`` for each file whose base name matches
    ``file_regex``. Files directly inside a directory whose path matches the
    exclusion pattern are skipped.

    Args:
        file_regex: Pattern (string or compiled) searched in each file's
            base name.
        srcdirs: Iterable of directory paths to walk.
        exclude_regex: Pattern (string or compiled) searched in each walked
            directory path. Defaults to the module-level ``exclusion_regex``.

    Returns:
        List of matching file paths. Directories and files are visited in
        sorted order, so the result is deterministic across runs.
    """
    if exclude_regex is None:
        exclude_regex = exclusion_regex
    all_files = []
    for srcdir in srcdirs:
        for root, dirs, files in os.walk(srcdir):
            # os.walk yields entries in arbitrary filesystem order; sorting
            # ``dirs`` in place makes the top-down traversal deterministic.
            dirs.sort()
            # The exclusion test depends only on ``root``: check it once per
            # directory rather than once per file.
            if re.search(exclude_regex, root):
                continue
            for f in sorted(files):
                if re.search(file_regex, f):
                    all_files.append(os.path.join(root, f))
    return all_files
def check_includes_in(src):
    """Report ``#include`` directives in ``src`` that use the wrong delimiters.

    A quoted include (``#include "x.h"``) is expected to name a file that
    exists relative to the directory containing ``src``; an angle-bracket
    include (``#include <x.h>``) is expected NOT to. Includes whose operand
    starts with neither delimiter (e.g. a macro) are ignored.

    Args:
        src: Path of the source file to scan.

    Returns:
        List of error strings of the form ``"Line:N use #include ..."`` with
        1-based line numbers; empty when the file is consistent.
    """
    errs = []
    src_dir = os.path.dirname(src)
    # Context manager: the original never closed the handle, leaking one per
    # scanned file. Decode as UTF-8 and replace undecodable bytes so a single
    # oddly-encoded file cannot abort the whole run.
    with open(src, encoding="utf-8", errors="replace") as fh:
        for line_num, line in enumerate(fh, start=1):
            line = RemoveComments.sub("", line)
            match = IncludeRegex.search(line)
            if match is None:
                continue
            val = match.group(1)
            delim = val[0]
            if delim not in ('"', "<"):
                # e.g. ``#include SOME_MACRO`` -- nothing to check.
                continue
            inc_file = val[1:-1]  # strip the surrounding "..." or <...>
            is_local = os.path.exists(os.path.join(src_dir, inc_file))
            if delim == '"' and not is_local:
                errs.append("Line:%d use #include <...>" % line_num)
            elif delim == "<" and is_local:
                errs.append('Line:%d use #include "..."' % line_num)
    return errs
def main():
    """Command-line entry point: scan the requested directories and report.

    Prints "include-check PASSED" when no scanned file has an inconsistent
    ``#include``. Otherwise prints every offending file with its errors and
    exits with status 1.
    """
    args = parse_args()
    all_files = list_all_source_file(args.regex_compiled, args.dirs)
    all_errs = {}
    for f in all_files:
        errs = check_includes_in(f)
        if errs:
            all_errs[f] = errs
    if not all_errs:
        print("include-check PASSED")
        return
    print("include-check FAILED! See below for errors...")
    for f, errs in all_errs.items():
        print("File: %s" % f)
        for e in errs:
            print(" %s" % e)
    # Exit statuses outside 0-127 are not portable (-1 becomes 255 on POSIX);
    # 1 is the conventional generic-failure code.
    sys.exit(1)


if __name__ == "__main__":
    main()