refactor: scan classDefs in advance.

This commit is contained in:
2025-01-30 20:56:00 +08:00
parent 79d65ebc0c
commit e64e8600c5
3 changed files with 129 additions and 128 deletions

View File

@@ -1,82 +0,0 @@
import util.cpp_language as CppUtil
_need_fix_includes_queue = dict()
_need_fix_members_queue = dict()
_class_defs_record = dict()
class ClassDefineRecord:
rpath = str()
is_template = bool()
is_empty = bool()
def __init__(self, rpath: str, is_template: bool, is_empty: bool):
self.rpath = rpath
self.is_template = is_template
self.is_empty = is_empty
def record_class_definition(
path: str, namespace: str, class_name: str, is_template: bool, is_empty: bool
):
assert len(path) > 0 and len(class_name) > 0
assert '::' not in class_name # c++ does not support forward declaration for nested class.
if namespace not in _class_defs_record:
_class_defs_record[namespace] = {}
assert class_name not in _class_defs_record[namespace], (
f'path = {path}, ns = {namespace}, cl = {class_name}'
)
_class_defs_record[namespace][class_name] = ClassDefineRecord(
path[path.find('src/') + 4 :], is_template, is_empty
)
def add_pending_fix_includes_queue(path: str, decls: list, member_typeset: list):
assert path not in _need_fix_includes_queue
if len(decls) > 0 and len(member_typeset) > 0:
_need_fix_includes_queue[path] = [decls, member_typeset]
def add_pending_fix_members_queue(path: str, member_typeset: list):
assert path not in _need_fix_members_queue
if len(member_typeset) > 0:
_need_fix_members_queue[path] = member_typeset
def _find_definition(decl: str, in_types: list) -> ClassDefineRecord | None:
find_decl = CppUtil.find_class_forward_declaration(decl)
assert find_decl
namespace = find_decl.namespace_decl
clazz = find_decl.class_decl
if not CppUtil.is_full_type_required_for_typeset(namespace, clazz, in_types):
return None
assert namespace in _class_defs_record, f'namespace = "{namespace}" is not recorded.'
assert clazz in _class_defs_record[namespace], (
f'namespace = "{namespace}", class = {clazz} is not recorded.'
)
return _class_defs_record[namespace][find_decl.class_decl]
def process():
# fix includes
for path, decl_and_types in _need_fix_includes_queue.items():
with open(path, 'r', encoding='utf-8') as file:
content = file.read()
for decl in decl_and_types[0]:
record = _find_definition(decl, decl_and_types[1])
if record:
include = f'#include "{record.rpath}"'
content = content.replace(decl, include)
with open(path, 'w', encoding='utf-8') as wfile:
wfile.write(content)
# fix members
# for path, types in _need_fix_members_queue.items():
# with open(path, 'r', encoding='utf-8') as file:
# content = file.read()
# with open(path, 'w', encoding='utf-8') as wfile:
# wfile.write(content)

View File

@@ -0,0 +1,85 @@
import os
import util.cpp_language as CppUtil
import util.string as StrUtil
# storage
defined_classes = dict()
class ClassDefine:
rpath = str()
is_template = bool()
is_empty = bool()
def __init__(self, rpath: str, is_template: bool, is_empty: bool):
self.rpath = rpath
self.is_template = is_template
self.is_empty = is_empty
def add_class_record(path: str, namespace: str, class_name: str, is_template: bool, is_empty: bool):
assert len(path) > 0 and len(class_name) > 0
assert '::' not in class_name # c++ does not support forward declaration for nested class.
if namespace not in defined_classes:
defined_classes[namespace] = {}
assert class_name not in defined_classes[namespace], (
f'path = {path}, ns = {namespace}, cl = {class_name}'
)
defined_classes[namespace][class_name] = ClassDefine(
path[path.find('src/') + 4 :], is_template, is_empty
)
def query_class_record_strict(namespace_decl: str, class_decl: str) -> ClassDefine:
assert namespace_decl in defined_classes, f'namespace = "{namespace_decl}" is not recorded.'
assert class_decl in defined_classes[namespace_decl], (
f'namespace = "{namespace_decl}", class = {class_decl} is not recorded.'
)
return defined_classes[namespace_decl][class_decl]
def process(path_to_file: str):
assert os.path.isfile(path_to_file)
if path_to_file.endswith('_HeaderOutputPredefine.h'):
return
with open(path_to_file, 'r', encoding='utf-8') as file:
# states
in_forward_declaration_list = False
current_namespace = []
# tmp
content = ''
for line in file.readlines():
stripped_line = line.strip()
# record forward declarations
if stripped_line.startswith('// auto generated forward declare list'):
in_forward_declaration_list = True
if stripped_line.startswith('// clang-format on') and in_forward_declaration_list:
in_forward_declaration_list = False
# record namespace & classes
if not in_forward_declaration_list:
if StrUtil.startswith_m(line, 'class ', 'struct ', 'union '): # ignore nested class
founded = CppUtil.find_class_definition(line)
if founded:
is_template = (
content[content.rfind('\n', 0, -1) :].strip().startswith('template ')
)
is_empty = stripped_line.endswith('{};')
add_class_record(
path_to_file,
'::'.join(current_namespace),
founded,
is_template,
is_empty,
)
founded = CppUtil.find_namespace_declaration(line)
if founded:
current_namespace.append(founded)
if '// namespace' in stripped_line:
current_namespace.pop()

View File

@@ -1,7 +1,7 @@
import os
import re
import header_postprocessor as HeaderPostProcessor
import header_preprocessor as HeaderPreProcessor
import util.cpp_language as CppUtil
import util.string as StrUtil
@@ -57,6 +57,21 @@ class Options:
self.set_variable(opt)
def try_translate_forward_declaration(
decl: str, typeset: list
) -> HeaderPreProcessor.ClassDefine | None:
find_decl = CppUtil.find_class_forward_declaration(decl)
assert find_decl, f'decl = {decl}'
namespace = find_decl.namespace_decl
clazz = find_decl.class_decl
if not CppUtil.is_full_type_required_for_typeset(namespace, clazz, typeset):
return None
return HeaderPreProcessor.query_class_record_strict(namespace, clazz)
def process(path_to_file: str, args: Options):
assert os.path.isfile(path_to_file)
@@ -81,7 +96,6 @@ def process(path_to_file: str, args: Options):
in_member_variable = False
in_forward_declaration_list = False
has_typed_storage = False
current_namespace = []
forward_declarations = []
member_variable_types = []
@@ -158,6 +172,8 @@ def process(path_to_file: str, args: Options):
type_name = matched[3]
var_name = matched[4]
member_variable_types.append(type_name)
security_check = ''
if args.add_sizeof_alignof_static_assertions:
security_check += f'\tstatic_assert(sizeof({var_name}) == {size});\n'
@@ -188,7 +204,6 @@ def process(path_to_file: str, args: Options):
content += f'\t{type_name} {var_name};\n{security_check}'
member_variable_types.append(type_name)
in_member_variable = False
continue
@@ -210,37 +225,16 @@ def process(path_to_file: str, args: Options):
assert False, 'unreachable'
# record forward declarations
# fix forward declarations
if stripped_line.startswith('// auto generated forward declare list'):
in_forward_declaration_list = True
if in_forward_declaration_list:
if StrUtil.startswith_m(stripped_line, 'class ', 'struct ', 'union ', 'namespace '):
forward_declarations.append(stripped_line)
if in_forward_declaration_list and StrUtil.startswith_m(
stripped_line, 'class ', 'struct ', 'union ', 'namespace '
):
forward_declarations.append(stripped_line)
if stripped_line.startswith('// clang-format on') and in_forward_declaration_list:
in_forward_declaration_list = False
# record namespace & classes
if not in_forward_declaration_list:
if StrUtil.startswith_m(line, 'class ', 'struct ', 'union '): # ignore nested class
founded = CppUtil.find_class_definition(line)
if founded:
is_template = (
content[content.rfind('\n', 0, -1) :].strip().startswith('template ')
)
is_empty = stripped_line.endswith('{};')
HeaderPostProcessor.record_class_definition(
path_to_file,
'::'.join(current_namespace),
founded,
is_template,
is_empty,
)
founded = CppUtil.find_namespace_declaration(line)
if founded:
current_namespace.append(founded)
if '// namespace' in stripped_line:
current_namespace.pop()
# remove useless thunks.
if '// NOLINTEND' in line and (in_useless_thunk or in_static_variable):
in_useless_thunk = False
@@ -256,15 +250,14 @@ def process(path_to_file: str, args: Options):
content = content[: content.rfind('\n')] # for nested classes
if not in_useless_thunk:
content += line
if args.fix_includes_for_member_variables and has_typed_storage:
for decl in forward_declarations:
class_define = try_translate_forward_declaration(decl, member_variable_types)
if class_define:
is_modified = True
content = content.replace(decl, f'#include "{class_define.rpath}"\n')
continue
if is_modified:
if args.fix_includes_for_member_variables and has_typed_storage:
HeaderPostProcessor.add_pending_fix_includes_queue(
path_to_file, forward_declarations, member_variable_types
)
if args.fix_size_for_type_with_empty_template_class and has_typed_storage:
HeaderPostProcessor.add_pending_fix_members_queue(
path_to_file, member_variable_types
)
with open(path_to_file, 'w', encoding='utf-8') as wfile:
wfile.write(content)
@@ -272,13 +265,18 @@ def process(path_to_file: str, args: Options):
def iterate(args: Options):
assert os.path.isdir(args.base_dir)
for root, dirs, files in os.walk(args.base_dir):
for file in files:
if CppUtil.is_header_file(file):
path = os.path.join(root, file)
# processing: executed immediately after preprocessing is completed.
process(path, args)
def _traverse_header(path_to_dir: str, fun):
for root, dirs, files in os.walk(path_to_dir):
for file in files:
if CppUtil.is_header_file(file):
path = os.path.join(root, file)
fun(path)
# post-processing: executed after all files have been processed.
# during processing, files that require post-processing will be marked.
HeaderPostProcessor.process()
def _preprocess(path_to_file):
HeaderPreProcessor.process(path_to_file)
def _process(path_to_file):
process(path_to_file, args)
_traverse_header(args.base_dir, _preprocess)
_traverse_header(args.base_dir, _process)