feat: include fixer.

This commit is contained in:
2025-01-27 19:31:27 +08:00
parent f20050ab0a
commit da3d1897f3
5 changed files with 520 additions and 189 deletions

View File

@@ -0,0 +1,254 @@
import os
import re
import predefine_subprocessor as PredefineProcessor
import util.cpp_language as CppUtil
from include_fixer import IncludeFixer
class Options:
    """Switches controlling which rewrites the header processor performs.

    An instance is built from the parsed command-line namespace: every switch
    is copied from the attribute of the same name, ``base_dir`` comes from
    ``args.path``, and a truthy ``args.all`` enables every function and
    variable switch.
    """

    base_dir: str = ''

    # functions
    remove_constructor_thunk: bool = False
    remove_destructor_thunk: bool = False
    remove_virtual_function_thunk: bool = False

    # variables
    remove_virtual_table_pointer_thunk: bool = False
    restore_static_variable: bool = False
    restore_member_variable: bool = False

    # others
    # only takes effect for TypedStorage, since the TypedStorage wrapper makes the full type unnecessary.
    fix_includes_for_member_variables: bool = True

    def __init__(self, args):
        self.base_dir = args.path
        # Each switch is mirrored from the identically named argparse attribute.
        for switch in (
            'remove_constructor_thunk',
            'remove_destructor_thunk',
            'remove_virtual_function_thunk',
            'remove_virtual_table_pointer_thunk',
            'restore_static_variable',
            'restore_member_variable',
        ):
            setattr(self, switch, getattr(args, switch))

        if args.all:
            self.set_all(True)

    def set_function(self, opt: bool):
        """Set every function-related switch to ``opt``."""
        for switch in (
            'remove_constructor_thunk',
            'remove_destructor_thunk',
            'remove_virtual_function_thunk',
        ):
            setattr(self, switch, opt)

    def set_variable(self, opt: bool):
        """Set every variable-related switch to ``opt``."""
        for switch in (
            'remove_virtual_table_pointer_thunk',
            'restore_static_variable',
            'restore_member_variable',
        ):
            setattr(self, switch, opt)

    def set_all(self, opt: bool):
        """Set the function and the variable switches to ``opt``."""
        self.set_function(opt)
        self.set_variable(opt)
def process(path_to_file: str, include_fixer: IncludeFixer | None, args: Options):
    """Rewrite a single generated header in place according to ``args``.

    A file whose name ends with ``_HeaderOutputPredefine.h`` is handed to
    ``PredefineProcessor.process`` and nothing else happens.  Any other file is
    scanned line by line and, depending on the enabled switches:

    * sections introduced by a thunk marker (constructor / destructor /
      vftable / virtual function) are dropped up to their ``// NOLINTEND``,
      together with the ``public:`` specifier that precedes the marker;
    * declarations in the ``// static variables`` section lose their ``()``
      and their reference so they read as static variable declarations;
    * ``::ll::TypedStorage`` members are replaced by the wrapped type and
      ``::ll::UntypedStorage`` members by an ``alignas`` ``std::byte`` array;
    * forward declarations, top-level class definitions and namespaces are
      collected and reported to ``include_fixer``.

    The file is written back only if at least one rewrite was triggered.

    Args:
        path_to_file: Path of an existing header file.
        include_fixer: Collector for class definitions and pending include
            fixes.  When it is missing, recording is skipped.
        args: Processing switches.
    """
    assert os.path.isfile(path_to_file)

    if path_to_file.endswith('_HeaderOutputPredefine.h'):
        PredefineProcessor.process(path_to_file)
        return

    # Section markers whose whole region is removed from the output.
    RECORDED_THUNKS = []
    if args.remove_constructor_thunk:
        RECORDED_THUNKS.append('// constructor thunks')
    if args.remove_destructor_thunk:
        RECORDED_THUNKS.append('// destructor thunk')
    if args.remove_virtual_table_pointer_thunk:
        RECORDED_THUNKS.append('// vftables')
    if args.remove_virtual_function_thunk:
        RECORDED_THUNKS.append('// virtual function thunks')

    # The signature allows ``include_fixer`` to be None; in that case the
    # recording steps are skipped instead of failing on the first class.
    use_include_fixer = bool(args.fix_includes_for_member_variables and include_fixer)

    # states
    is_modified = False
    in_useless_thunk = False
    in_static_variable = False
    in_member_variable = False
    in_forward_declaration_list = False
    has_typed_storage = False

    current_namespace = []
    forward_declarations = []
    member_variable_types = []

    ll_typed_regex = re.compile(r'TypedStorage<(\d+), (\d+), (.*?)> (\w+);')
    ll_untyped_regex = re.compile(r'UntypedStorage<(\d+), (\d+)> (\w+);')

    def regex_preprocess_name(con: str):
        # typed storage can be very complex, so we need to preprocess it.
        # TODO: find a better way.
        return re.sub(r'\s+', ' ', con).replace('< ', '<')

    # Rewritten file text.  It stays a plain string because the logic below
    # searches backwards in it and truncates it.
    content = ''

    with open(path_to_file, 'r', encoding='utf-8') as file:
        for line in file:
            stripped_line = line.strip()

            # restore static member variable:
            if args.restore_static_variable and '// static variables' in line:
                in_static_variable = True
                is_modified = True
            if in_static_variable and stripped_line.endswith(';'):
                if not stripped_line.startswith('MCAPI'):  # declaration may not be on one line
                    # Pull the already emitted beginning of the declaration
                    # back out of the output and join it with this line.
                    begin_pos = content.rfind('MCAPI')
                    stripped_line = content[begin_pos:] + stripped_line
                    content = content[:begin_pos]
                    stripped_line = stripped_line.strip()

                # remove parameter list (convert to static variable)
                call_spec_pos = stripped_line.rfind('()')
                assert call_spec_pos != -1
                stripped_line = stripped_line[:call_spec_pos] + stripped_line[call_spec_pos + 2 :]

                # remove reference
                refsym_pos = stripped_line.rfind('&')  # T&
                tlpsym_pos = stripped_line.rfind('>')  # ::std::add_lvalue_reference_t<T>
                assert refsym_pos != -1 or tlpsym_pos != -1, f'in {path_to_file}'
                if tlpsym_pos == -1 or refsym_pos > tlpsym_pos:
                    stripped_line = stripped_line[:refsym_pos] + stripped_line[refsym_pos + 1 :]
                elif refsym_pos == -1 or tlpsym_pos > refsym_pos:
                    # C-style arrays must have '[]' written after the variable name
                    stripped_line = stripped_line[:tlpsym_pos] + '>' + stripped_line[tlpsym_pos:]
                    stripped_line = stripped_line.replace(
                        '::std::add_lvalue_reference_t<',
                        '::std::remove_reference_t<::std::add_lvalue_reference_t<',
                    )
                content += f'\t{stripped_line}\n'
                continue

            # restore member variable:
            if args.restore_member_variable and '::ll::' in line:  # union { ... };
                in_member_variable = True
                is_modified = True
            if in_member_variable and stripped_line.endswith(';'):
                if not stripped_line.startswith('::ll::'):
                    # Multi-line declaration: rejoin it with the part that was
                    # already copied to the output.
                    begin_pos = content.rfind('::ll::')
                    stripped_line = content[begin_pos:] + stripped_line
                    content = content[:begin_pos]
                    stripped_line = stripped_line.strip()

                # ::ll::TypedStorage<Alignment, Size, T> mVar;
                if 'TypedStorage' in stripped_line:
                    has_typed_storage = True
                    matched = ll_typed_regex.search(regex_preprocess_name(stripped_line))
                    assert matched and matched.lastindex == 4, (
                        f'in {path_to_file}, line="{stripped_line}"'
                    )
                    # Alignment and size (groups 1 and 2) are not needed here.
                    type_name = matched[3]
                    var_name = matched[4]

                    if type_name.endswith(']'):  # is c-style array
                        array_length = int(type_name[type_name.find('[') + 1 : type_name.find(']')])
                        type_name = type_name[: type_name.find('[')]
                        var_name = f'{var_name}[{array_length}]'

                    fun_ptr_pos = type_name.find('(*)')
                    if -1 != fun_ptr_pos:  # is c-style function ptr
                        # The variable name goes inside the '(*)' declarator.
                        type_name = (
                            type_name[: fun_ptr_pos + 2] + var_name + type_name[fun_ptr_pos + 2 :]
                        )
                        var_name = ''

                    content += f'\t{type_name} {var_name};\n'
                    member_variable_types.append(type_name)
                    in_member_variable = False
                    continue

                # ::ll::UntypedStorage<Alignment, Size> mVar;
                if 'UntypedStorage' in stripped_line:
                    matched = ll_untyped_regex.search(regex_preprocess_name(stripped_line))
                    assert matched and matched.lastindex == 3, (
                        f'in {path_to_file}, line="{stripped_line}"'
                    )
                    align = matched[1]
                    size = matched[2]
                    var_name = matched[3]
                    content += f'\talignas({align}) std::byte {var_name}[{size}];\n'
                    in_member_variable = False
                    continue

                assert False, 'unreachable'

            # record forward declarations
            if stripped_line.startswith('// auto generated forward declare list'):
                in_forward_declaration_list = True
            if in_forward_declaration_list and stripped_line.startswith(
                ('class', 'struct', 'namespace')
            ):
                forward_declarations.append(stripped_line)
            if stripped_line.startswith('// clang-format on') and in_forward_declaration_list:
                in_forward_declaration_list = False

            # record namespace & classes
            if use_include_fixer and not in_forward_declaration_list:
                # ignore nested class, FIXME: ignore template class
                if line.startswith(('class', 'struct')):
                    founded = CppUtil.find_class_definition(line)
                    if founded:
                        include_fixer.record_class_definition(
                            path_to_file, '::'.join(current_namespace), founded
                        )
                founded = CppUtil.find_namespace_declaration(line)
                if founded:
                    current_namespace.append(founded)
                # A closing comment without a recorded opening must not raise.
                if '// namespace' in stripped_line and current_namespace:
                    current_namespace.pop()

            # remove useless thunks.
            if '// NOLINTEND' in line and (in_useless_thunk or in_static_variable):
                in_useless_thunk = False
                in_static_variable = False
                continue  # don't add this line.
            for token in RECORDED_THUNKS:
                if token in line:
                    in_useless_thunk = True
                    is_modified = True
                    # remove previous access specifier.  When there is none,
                    # the text emitted so far is left untouched (slicing with
                    # the -1 returned by rfind would eat unrelated text).
                    access_pos = content.rfind('public:')
                    if access_pos != -1:
                        content = content[:access_pos]
                        # for nested classes: also drop the indentation and the
                        # line break in front of the specifier.
                        content = content[: max(content.rfind('\n'), 0)]
            if not in_useless_thunk:
                content += line

    if is_modified:
        if use_include_fixer and has_typed_storage:
            include_fixer.add_pending_fix_queue(
                path_to_file, forward_declarations, member_variable_types
            )
        with open(path_to_file, 'w', encoding='utf-8') as wfile:
            wfile.write(content)
def iterate(args: Options):
    """Process every header file found below ``args.base_dir``.

    When ``args.fix_includes_for_member_variables`` is set, one
    ``IncludeFixer`` is shared by all files and its ``run_fix`` is called once
    after the whole tree has been processed.

    Args:
        args: Processing switches; ``base_dir`` must be an existing directory.
    """
    assert os.path.isdir(args.base_dir)

    # ``process`` declares its fixer parameter as ``IncludeFixer | None``, so a
    # disabled fixer is represented by None rather than by the flag value.
    include_fixer = IncludeFixer() if args.fix_includes_for_member_variables else None

    for root, _dirs, files in os.walk(args.base_dir):
        for file in files:
            if CppUtil.is_header_file(file):
                process(os.path.join(root, file), include_fixer, args)

    if include_fixer is not None:
        include_fixer.run_fix()

View File

@@ -0,0 +1,164 @@
import os
import re
import util.cpp_language as CppUtil
class IncludeFixer:
    """Replaces forward declarations by includes where a full type is needed.

    Class definitions are recorded per namespace while headers are scanned;
    headers are queued together with their forward declarations and member
    variable types; ``run_fix`` then rewrites the queued headers.
    """

    def __init__(self):
        # Both containers belong to the instance.  As class-level dicts they
        # would be shared by every IncludeFixer ever created.
        # path -> [forward declarations, member variable types]
        self._need_fix_queue = {}
        # namespace -> {class name -> include path}
        self._class_defs_record = {}

    def record_class_definition(self, path: str, namespace: str, class_name: str):
        """Remember that ``class_name`` in ``namespace`` is defined in ``path``.

        The stored include path is the part of ``path`` after the first
        ``src/``, always written with forward slashes.  When ``path`` has no
        ``src/`` component, the whole normalised path is stored.
        """
        assert len(path) > 0 and len(class_name) > 0
        assert '::' not in class_name  # c++ does not support forward declaration for nested class.

        # Paths produced on Windows use backslashes; normalise them so the
        # marker is found and the resulting #include uses forward slashes.
        normalized = path.replace('\\', '/')
        marker_pos = normalized.find('src/')
        if marker_pos == -1:
            include_path = normalized
        else:
            include_path = normalized[marker_pos + len('src/') :]

        definitions = self._class_defs_record.setdefault(namespace, {})
        assert class_name not in definitions
        definitions[class_name] = include_path

    def add_pending_fix_queue(self, path: str, decls: list, member_typeset: list):
        """Queue ``path`` for ``run_fix``; ignored when either list is empty."""
        assert os.path.isfile(path)
        assert path not in self._need_fix_queue
        if decls and member_typeset:
            self._need_fix_queue[path] = [decls, member_typeset]

    def _find_definition(self, decl: str, in_types: list):
        """Return the include path for ``decl``, or None if no include is needed."""
        find_decl = CppUtil.find_class_forward_declaration(decl)
        assert find_decl

        namespace = find_decl.namespace_decl
        clazz = find_decl.class_decl
        if not _is_full_type_needed(namespace, clazz, in_types):
            return None

        assert namespace in self._class_defs_record, f'namespace not recorded, {namespace}'
        assert clazz in self._class_defs_record[namespace], f'{clazz} not recorded, in {namespace}'
        return self._class_defs_record[namespace][clazz]

    def run_fix(self):
        """Rewrite every queued header, swapping forward declarations for includes."""
        for path, (decls, member_types) in self._need_fix_queue.items():
            with open(path, 'r', encoding='utf-8') as file:
                content = file.read()

            fixed = content
            for decl in decls:
                define_location = self._find_definition(decl, member_types)
                if define_location:
                    fixed = fixed.replace(decl, f'#include "{define_location}"')

            # Leave the file alone when no declaration had to be replaced.
            if fixed != content:
                with open(path, 'w', encoding='utf-8') as wfile:
                    wfile.write(fixed)
def _is_full_type_needed(namespace_decl: str, class_decl: str, in_types: list):
# Y: T
# Y: std::optional<T>
# Y: std::variant<T>
# Y: std::array<T, _>
# Y: std::pair<T, T>
# Y: std::unordered_set<T>
# Y: std::unordered_map<T, _>
# Y: std::deque<T> // under msstl only
# Y: std::queue<T> // under msstl only
# N: T&
# N: T*
# N: std::map<T, T>
# N: std::shared_ptr<T>
# N: std::unique_ptr<T>
# N: std::weak_ptr<T>
# N: std::vector<T>
# N: std::set<T>
# N: std::unordered_map<_, T>
# N: std::function<T(T)>
def is_subtk_ends_with(full: str, tk: str, whats: list):
founded = False
for matched in re.finditer(rf'\b{re.escape(tk)}\b', full):
founded = True
if len(full) > matched.end():
for what in whats:
if full[matched.end() : matched.end() + len(what)] == what:
return founded, True
return founded, False
def find_template_name(full: str, what: str):
for matched in re.finditer(rf'\b{re.escape(what)}\b', full):
endpos = matched.start()
while True:
r_angle_bracket_pos = full.rfind('>', 0, endpos)
l_angle_bracket_pos = full.rfind('<', 0, endpos)
if l_angle_bracket_pos == -1:
return None
if r_angle_bracket_pos > l_angle_bracket_pos:
endpos = l_angle_bracket_pos
continue
ret = full[:l_angle_bracket_pos]
matched_non_name = list(re.finditer(r'[^a-zA-Z_]', ret))
if len(matched_non_name) > 0:
ret = ret[matched_non_name[-1].start() + 1 :]
assert len(ret) > 0
return ret
return None
for type_name in in_types:
founded, endswith = is_subtk_ends_with(
type_name, class_decl, ['&', '*', ' const&', ' const*']
)
# is not reference or pointer type
if founded and not endswith:
# is template parameter?
template_name = find_template_name(type_name, class_decl)
if template_name:
if template_name in [ # NOT Need full type.
'map',
'shared_ptr',
'unique_ptr',
'weak_ptr',
'vector',
'queue',
'set',
'function',
]:
pass # don't return false directly
elif template_name in [ # Need full type.
'optional',
'variant',
'array',
'pair',
'unordered_set',
'deque',
'queue',
]:
return True
elif template_name in [ # EMPTY TEMPLATE CLASS
'ScriptFilteredEventSignal',
'OwnerPtr',
'UniqueOwnerPointer',
'NotNullNonOwnerPtr',
'NonOwnerPointer',
'ServiceRegistrationToken',
'IDType',
'WeakRef',
'SubChunkStorage',
'ServiceReference',
'typeid_t',
'List',
'MemoryPool',
'ThreadOwner',
'StrongTypedObjectHandle',
'Promise',
'Factory',
'Publisher',
'ServiceRegistrationToken',
]:
pass
else:
return True # on default
else:
return True # not a template parameter
return False

View File

@@ -1,193 +1,6 @@
import os
import argparse
import re
class ProcessorOptions:
    """Switches selecting which header rewrites are performed.

    Values are copied from the parsed command-line namespace (``base_dir``
    from ``args.path``); a truthy ``args.all`` enables every switch.
    """

    base_dir: str = ''

    # functions
    remove_constructor_thunk: bool = False
    remove_destructor_thunk: bool = False
    remove_virtual_function_thunk: bool = False

    # variables
    remove_virtual_table_pointer_thunk: bool = False
    restore_static_variable: bool = False
    restore_member_variable: bool = False

    def __init__(self, args):
        self.base_dir = args.path
        # Switches share their names with the command-line attributes.
        for switch in (
            'remove_constructor_thunk',
            'remove_destructor_thunk',
            'remove_virtual_function_thunk',
            'remove_virtual_table_pointer_thunk',
            'restore_static_variable',
            'restore_member_variable',
        ):
            setattr(self, switch, getattr(args, switch))

        if args.all:
            self.set_all(True)

    def set_function(self, opt: bool):
        """Set the three function thunk switches to ``opt``."""
        self.remove_constructor_thunk = opt
        self.remove_destructor_thunk = opt
        self.remove_virtual_function_thunk = opt

    def set_variable(self, opt: bool):
        """Set the three variable switches to ``opt``."""
        self.remove_virtual_table_pointer_thunk = opt
        self.restore_static_variable = opt
        self.restore_member_variable = opt

    def set_all(self, opt: bool):
        """Set every switch to ``opt``."""
        for apply in (self.set_function, self.set_variable):
            apply(opt)
def process_headers(path_to_file: str, args: ProcessorOptions):
    """Rewrite a single generated header in place according to ``args``.

    The file is scanned line by line and, depending on the enabled switches:

    * sections introduced by a thunk marker (constructor / destructor /
      vftable / virtual function) are dropped up to their ``// NOLINTEND``,
      together with the ``public:`` specifier that precedes the marker;
    * declarations in the ``// static variables`` section lose their ``()``
      and their reference so they read as static variable declarations;
    * ``::ll::TypedStorage`` members are replaced by the wrapped type and
      ``::ll::UntypedStorage`` members by an ``alignas`` ``std::byte`` array.

    The file is written back only if at least one rewrite was triggered.

    Args:
        path_to_file: Path of an existing header file.
        args: Processing switches.
    """
    assert os.path.isfile(path_to_file)

    # Section markers whose whole region is removed from the output.
    RECORDED_THUNKS = []
    if args.remove_constructor_thunk:
        RECORDED_THUNKS.append('// constructor thunks')
    if args.remove_destructor_thunk:
        RECORDED_THUNKS.append('// destructor thunk')
    if args.remove_virtual_table_pointer_thunk:
        RECORDED_THUNKS.append('// vftables')
    if args.remove_virtual_function_thunk:
        RECORDED_THUNKS.append('// virtual function thunks')

    # states
    is_modified = False
    in_useless_thunk = False
    in_static_variable = False
    in_member_variable = False

    ll_typed_regex = re.compile(r'TypedStorage<(\d+), (\d+), (.*?)> (\w+);')
    ll_untyped_regex = re.compile(r'UntypedStorage<(\d+), (\d+)> (\w+);')

    def regex_preprocess_name(con: str):
        # typed storage can be very complex, so we need to preprocess it.
        # TODO: find a better way.
        return re.sub(r'\s+', ' ', con).replace('< ', '<')

    # Rewritten file text.  It stays a plain string because the logic below
    # searches backwards in it and truncates it.
    content = ''

    with open(path_to_file, 'r', encoding='utf-8') as file:
        for line in file:
            this_line = line.strip()

            # restore static member variable thunk:
            if args.restore_static_variable and '// static variables' in line:
                in_static_variable = True
                is_modified = True
            if in_static_variable and this_line.endswith(';'):
                if not this_line.startswith('MCAPI'):  # declaration may not be on one line
                    # Pull the already emitted beginning of the declaration
                    # back out of the output and join it with this line.
                    begin_pos = content.rfind('MCAPI')
                    this_line = content[begin_pos:] + this_line
                    content = content[:begin_pos]
                    this_line = this_line.strip()

                # remove parameter list (convert to static variable)
                call_spec_pos = this_line.rfind('()')
                assert call_spec_pos != -1
                this_line = this_line[:call_spec_pos] + this_line[call_spec_pos + 2 :]

                # remove reference
                refsym_pos = this_line.rfind('&')  # T&
                tlpsym_pos = this_line.rfind('>')  # ::std::add_lvalue_reference_t<T>
                assert refsym_pos != -1 or tlpsym_pos != -1, f'in {path_to_file}'
                if tlpsym_pos == -1 or refsym_pos > tlpsym_pos:
                    this_line = this_line[:refsym_pos] + this_line[refsym_pos + 1 :]
                elif refsym_pos == -1 or tlpsym_pos > refsym_pos:
                    # C-style arrays must have '[]' written after the variable name
                    this_line = this_line[:tlpsym_pos] + '>' + this_line[tlpsym_pos:]
                    this_line = this_line.replace(
                        '::std::add_lvalue_reference_t<',
                        '::std::remove_reference_t<::std::add_lvalue_reference_t<',
                    )
                content += f'\t{this_line}\n'
                continue

            if args.restore_member_variable and '::ll::' in line:  # union { ... };
                in_member_variable = True
                is_modified = True
            if in_member_variable and this_line.endswith(';'):
                # ::ll::TypedStorage<Alignment, Size, T> mVar;
                # ::ll::UntypedStorage<Alignment, Size> mVar;
                if not this_line.startswith('::ll::'):
                    # Multi-line declaration: rejoin it with the part that was
                    # already copied to the output.
                    begin_pos = content.rfind('::ll::')
                    this_line = content[begin_pos:] + this_line
                    content = content[:begin_pos]
                    this_line = this_line.strip()

                if 'TypedStorage' in this_line:
                    matched = ll_typed_regex.search(regex_preprocess_name(this_line))
                    assert matched and matched.lastindex == 4, (
                        f'in {path_to_file}, line="{this_line}"'
                    )
                    # Alignment and size (groups 1 and 2) are not needed here.
                    type_name = matched[3]
                    var_name = matched[4]
                    content += f'\t{type_name} {var_name};\n'
                    in_member_variable = False
                    continue

                if 'UntypedStorage' in this_line:
                    matched = ll_untyped_regex.search(regex_preprocess_name(this_line))
                    assert matched and matched.lastindex == 3, (
                        f'in {path_to_file}, line="{this_line}"'
                    )
                    align = matched[1]
                    size = matched[2]
                    var_name = matched[3]
                    content += f'\talignas({align}) std::byte {var_name}[{size}];\n'
                    in_member_variable = False
                    continue

                assert False, 'unreachable'

            # remove useless thunks.
            if '// NOLINTEND' in line and (in_useless_thunk or in_static_variable):
                in_useless_thunk = False
                in_static_variable = False
                continue  # don't add this line.
            for token in RECORDED_THUNKS:
                if token in line:
                    in_useless_thunk = True
                    is_modified = True
                    # remove previous access specifier.  When there is none,
                    # the text emitted so far is left untouched (slicing with
                    # the -1 returned by rfind would eat unrelated text).
                    access_pos = content.rfind('public:')
                    if access_pos != -1:
                        content = content[:access_pos]
                        # for nested classes: also drop the indentation and the
                        # line break in front of the specifier.
                        content = content[: max(content.rfind('\n'), 0)]
            if not in_useless_thunk:
                content += line

    if is_modified:
        with open(path_to_file, 'w', encoding='utf-8') as wfile:
            wfile.write(content)
def iterate_headers(args: ProcessorOptions):
    """Run ``process_headers`` on every ``.h`` / ``.hpp`` file below ``args.base_dir``."""
    assert os.path.isdir(args.base_dir)

    header_suffixes = ('.h', '.hpp')
    for folder, _subfolders, filenames in os.walk(args.base_dir):
        for filename in filenames:
            if not filename.endswith(header_suffixes):
                continue
            process_headers(os.path.join(folder, filename), args)
import header_processor as HeaderProcessor
def main():
@@ -206,7 +19,7 @@ def main():
args = parser.parse_args()
iterate_headers(ProcessorOptions(args))
HeaderProcessor.iterate(HeaderProcessor.Options(args))
print('done.')

View File

@@ -0,0 +1,13 @@
"""
Preprocessor for _HeaderOutputPredefine.h
"""
def process(path_to_file: str):
    """Append an ``#include <winsock2.h>`` line to the file at ``path_to_file``."""
    with open(path_to_file, 'r', encoding='utf-8') as source:
        original_text = source.read()

    # The include goes on a new line after the existing text.
    patched_text = f'{original_text}\n#include <winsock2.h>'

    with open(path_to_file, 'w', encoding='utf-8') as target:
        target.write(patched_text)

View File

@@ -0,0 +1,87 @@
"""
C++ Language Utility
* some methods may not be designed to be universal.
"""
class ForwardDeclaration:
    """Namespace and class name taken from one forward declaration line."""

    # Text between ``namespace`` and ``{``; empty without a namespace.
    namespace_decl: str = ''
    # Text between the ``class`` / ``struct`` keyword and ``;``.
    class_decl: str = ''

    def __init__(self, namespace_decl: str, class_decl: str):
        self.namespace_decl, self.class_decl = namespace_decl, class_decl
def is_header_file(path: str):
    """Return True when ``path`` ends with ``.h`` or ``.hpp``."""
    header_suffixes = ('.h', '.hpp')
    return path.endswith(header_suffixes)
def find_class_definition(line: str) -> str | None:
# class A (no quotation mark)
# class A :
# class A {
# class A { ... }; (in single line)
specifier_size = len('class')
class_pos = line.find('class ')
struct_pos = line.find('struct ')
assert class_pos == -1 or struct_pos == -1, f'line = {line}, c = {class_pos}, s = {struct_pos}'
left_brace_pos = line.find('{')
semicolon_pos = line.find(';')
if semicolon_pos != -1 and (left_brace_pos == -1 or semicolon_pos < left_brace_pos):
return None # is forward decl
if class_pos == -1:
if struct_pos == -1:
return None # is not class defs
specifier_size = len('struct')
class_pos = struct_pos
end_pos = len(line)
colon_pos = line.find(':')
if left_brace_pos != -1:
end_pos = min(end_pos, left_brace_pos)
if colon_pos != -1:
end_pos = min(end_pos, colon_pos)
return line[class_pos + specifier_size : end_pos].strip()
def find_class_forward_declaration(line: str) -> ForwardDeclaration | None:
    """Parse a one-line forward declaration.

    Accepted shapes::

        class A;
        namespace B { class A; }

    Returns:
        A ``ForwardDeclaration`` holding the namespace (empty when the line has
        no ``namespace ... {`` part) and the class name, or None when the line
        has no ``;`` or no ``class`` / ``struct`` keyword.
    """
    class_pos = line.find('class ')
    struct_pos = line.find('struct ')
    assert class_pos == -1 or struct_pos == -1

    semicolon_pos = line.find(';')
    if semicolon_pos == -1:
        return None

    # The name starts right after whichever keyword is present.
    if class_pos != -1:
        name_start = class_pos + len('class')
    elif struct_pos != -1:
        name_start = struct_pos + len('struct')
    else:
        return None

    keyword = 'namespace'
    namespace_pos = line.find(keyword)
    left_brace_pos = line.find('{')
    namespace_name = ''
    if namespace_pos != -1 and left_brace_pos != -1:
        namespace_name = line[namespace_pos + len(keyword) : left_brace_pos].strip()

    class_name = line[name_start:semicolon_pos].strip()
    return ForwardDeclaration(namespace_name, class_name)
def find_namespace_declaration(line: str) -> str | None:
namespace_pos = line.find('namespace')
left_brace_pos = line.find('{')
if namespace_pos == -1 or left_brace_pos == -1:
return None
return line[namespace_pos + len('namespace') : left_brace_pos].strip()