"""Generate Python ``.pyi`` stubs for pybind11 ``m.def(...)`` bindings.

The module scans a C++ source tree with regular expressions and bracket
counting (it is a set of heuristics, not a real C++ parser):

1. ``build_cpp_function_index`` maps function base names to signatures.
2. ``extract_m_def_statements`` collects raw ``m.def(...)`` statements.
3. ``parse_m_def_statement`` / ``parse_cpp_signature`` turn both into dicts.
4. ``generate_pyi_file`` renders the result as a stub file.
"""

import keyword
import re
from pathlib import Path

# File suffixes scanned when indexing C++ functions / looking for m.def calls.
_INDEX_EXTENSIONS = frozenset({'.cpp', '.cc', '.cxx', '.c', '.hpp', '.h'})
_M_DEF_EXTENSIONS = frozenset({'.hpp', '.cpp', '.h', '.cc'})
# In these files a bare declaration (ending in ';') is accepted as a signature.
_HEADER_EXTENSIONS = frozenset({'.h', '.hpp', '.cuh'})

# "<return type> <possibly::qualified::name> (" -- candidate function header.
_CPP_FUNCTION_RE = re.compile(
    r'([\w:\s*<&>,\[\]\(\)]+?)'
    r'\s+'
    r'([a-zA-Z_][a-zA-Z0-9_:]*)'
    r'\s*\(',
)
# Leading tokens that show a match is a statement, not a function header.
_NON_TYPE_FIRST_TOKENS = frozenset({
    'return', 'if', 'else', 'for', 'while', 'switch', 'case', 'throw',
    'catch', 'auto',
})

_M_DEF_CALL_RE = re.compile(r'm\.def\s*\(')
# One or more adjacent C++ string literals (optionally prefixed / raw).
_STRING_LITERAL_RE = re.compile(r'^(?:(?:u8|[uUL])?R?"(?:[^"\\]|\\.)*"\s*)+$')

# Whole identifier at the very end of a parameter declaration.
_TRAILING_IDENTIFIER_RE = re.compile(
    r'(?<![A-Za-z0-9_])([A-Za-z_][A-Za-z0-9_]*)$'
)
# A trailing identifier from this set belongs to the type, not to a name.
_TYPE_KEYWORDS = frozenset({
    'int', 'long', 'short', 'char', 'bool', 'float', 'double', 'void',
    'auto', 'const', 'static', 'volatile', 'mutable', 'unsigned', 'signed',
})
# If the text before the trailing identifier consists only of these tokens,
# the identifier is the type itself (e.g. ``const Foo``), not a name.
_TYPE_ONLY_PREFIX_TOKENS = frozenset({
    'const', 'volatile', 'struct', 'class', 'enum', 'typename',
})

# C++ integer literal with optional u/l/ll suffix; group 1 is the number.
_INT_LITERAL_RE = re.compile(
    r'^([+-]?\d+)(?:[uU](?:ll|LL|[lL])?|(?:ll|LL|[lL])[uU]?)?$'
)
# C++ floating literal with optional f/l suffix; group 1 is the number.
_FLOAT_LITERAL_RE = re.compile(
    r'^([+-]?(?:\d+\.\d*|\.\d+|\d+(?=[eE]))(?:[eE][+-]?\d+)?)[fFlL]?$'
)


def _iter_source_files(root_path, extensions):
    """Yield files under *root_path* whose lower-cased suffix is in *extensions*.

    Paths are yielded in sorted order so that results do not depend on the
    file system's directory enumeration order.
    """
    for file_path in sorted(Path(root_path).rglob('*')):
        if file_path.suffix.lower() in extensions and file_path.is_file():
            yield file_path


def _read_source_file(file_path):
    """Return the text of *file_path*, or ``None`` (after a message) on failure.

    Undecodable bytes are dropped (``errors='ignore'``).
    """
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()
    except OSError as e:
        print(f'Failed to read file {file_path}: {e}')
        return None


def _find_closing_paren(text, open_pos):
    """Return the index of the ')' matching the '(' at *open_pos*, or -1.

    Only round brackets are counted; string literals and comments are not
    treated specially.
    """
    depth = 0
    for pos in range(open_pos, len(text)):
        ch = text[pos]
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return pos
    return -1


def build_cpp_function_index(root_path):
    """Build a ``{base_name: signature}`` index of C++ functions under a tree.

    Args:
        root_path: Directory that is searched recursively.

    Returns:
        dict mapping a function's unqualified name to the text from its return
        type up to the closing ')' of its parameter list. When several
        candidates share a name, the first one found (files visited in sorted
        path order) is kept.

    A candidate is accepted when a '{' occurs within the 500 characters after
    the parameter list, or, in header files, when a ';' occurs before any '{'.
    """
    func_index = {}

    for file_path in _iter_source_files(root_path, _INDEX_EXTENSIONS):
        content = _read_source_file(file_path)
        if content is None:
            continue

        # Drop preprocessor directives and whole-line '//' comments.
        content = '\n'.join(
            line for line in content.split('\n')
            if not line.strip().startswith(('#', '//'))
        )
        is_header = file_path.suffix.lower() in _HEADER_EXTENSIONS

        for match in _CPP_FUNCTION_RE.finditer(content):
            return_type_part = match.group(1).strip()
            full_func_name = match.group(2).strip()

            if not return_type_part or not re.match(r'^[a-zA-Z_]', return_type_part):
                continue
            if return_type_part.split()[0] in _NON_TYPE_FIRST_TOKENS:
                continue

            # 'ns::Class::method' -> 'method'
            base_name = full_func_name.split('::')[-1]
            if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', base_name):
                continue

            # The pattern ends with '(', so match.end() - 1 is that paren.
            pos = _find_closing_paren(content, match.end() - 1)
            if pos == -1:
                continue

            # Reject matches that begin in the middle of a token sequence:
            # the last non-blank character before the match must not be an
            # identifier character.
            match_start = match.start()
            context_before = content[max(0, match_start - 50):match_start]
            if context_before and re.search(r'[a-zA-Z0-9_]$', context_before.rstrip()):
                continue

            after_paren = content[pos + 1:pos + 500]
            has_brace = '{' in after_paren
            has_semicolon = ';' in after_paren.split('{')[0]
            if has_brace or (is_header and has_semicolon):
                full_signature = content[match.start(1):pos + 1].strip()
                func_index.setdefault(base_name, full_signature)

    return func_index


class BracketTracker:
    """
    Tracks nesting levels of various brackets in C++ code:
    - () → paren
    - [] → bracket
    - {} → brace
    - <> → angle (treated as template brackets only at top level)

    Feed characters one at a time through update(); is_top_level() reports
    whether the position just after the last character is outside every
    bracket kind.
    """

    def __init__(self):
        self.paren = 0    # ()
        self.bracket = 0  # []
        self.brace = 0    # {}
        self.angle = 0    # <>

    def update(self, char: str):
        """Adjust the counters for *char*; non-bracket characters are ignored."""
        if char == '(':
            self.paren += 1
        elif char == ')':
            self.paren -= 1
        elif char == '[':
            self.bracket += 1
        elif char == ']':
            self.bracket -= 1
        elif char == '{':
            self.brace += 1
        elif char == '}':
            self.brace -= 1
        # '<' and '>' only count as template delimiters when not inside
        # (), [] or {}; a '>' with no open '<' is ignored.
        elif char == '<' and self._in_top_level_of_other_brackets():
            self.angle += 1
        elif char == '>' and self.angle > 0 and self._in_top_level_of_other_brackets():
            self.angle -= 1

    def _in_top_level_of_other_brackets(self):
        """Return True when outside parentheses, square brackets and braces."""
        return self.paren == 0 and self.bracket == 0 and self.brace == 0

    def is_top_level(self):
        """Return True when all four bracket counters are zero."""
        return self._in_top_level_of_other_brackets() and self.angle == 0


def _split_top_level_commas(text):
    """Split *text* on commas that are outside all brackets and string literals.

    Pieces are stripped. Empty pieces between commas are kept; a trailing
    empty piece is kept only if something (e.g. whitespace) followed the last
    comma. Only double-quoted literals with backslash escapes are recognised.
    """
    pieces = []
    current = []
    tracker = BracketTracker()
    in_string = False
    escaped = False

    for ch in text:
        if in_string:
            current.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
            current.append(ch)
            continue
        if ch in '()[]{}<>':
            tracker.update(ch)
        if ch == ',' and tracker.is_top_level():
            pieces.append(''.join(current).strip())
            current = []
        else:
            current.append(ch)

    if current:
        pieces.append(''.join(current).strip())
    return pieces


def _find_top_level_equals(text):
    """Return the index of the first '=' outside all brackets, or -1."""
    tracker = BracketTracker()
    for pos, ch in enumerate(text):
        if ch in '()[]{}<>':
            tracker.update(ch)
        elif ch == '=' and tracker.is_top_level():
            return pos
    return -1


def _find_m_def_end(lines, start):
    """Return the index of the line on which the m.def call from *start* closes.

    Parentheses are counted from the first '(' that is preceded by ``m.def``
    on its line. Returns -1 if the call is never closed.
    """
    depth = 0
    started = False
    for j in range(start, len(lines)):
        current_line = lines[j]
        for k, ch in enumerate(current_line):
            if ch == '(':
                if not started and _M_DEF_CALL_RE.search(current_line[:k + 1]):
                    started = True
                if started:
                    depth += 1
            elif ch == ')' and started:
                depth -= 1
                if depth == 0:
                    return j
    return -1


def _collect_m_def_statements(content):
    """Return the raw ``m.def(...)`` statements found in *content*.

    Each statement is the full text of the lines it spans (right-stripped),
    so it may include text before ``m.def`` and after the closing ')'.
    Lines starting with '//' or '/*' are skipped. Only the first call on a
    line is captured.
    """
    statements = []
    lines = content.splitlines(keepends=True)
    i = 0
    while i < len(lines):
        line = lines[i]
        if 'm.def(' in line and not line.lstrip().startswith(('//', '/*')):
            end = _find_m_def_end(lines, i)
            if end != -1:
                statements.append(''.join(lines[i:end + 1]).rstrip())
                i = end  # resume after the statement's last line
        i += 1
    return statements


def extract_m_def_statements(root_path):
    """
    Scan all C++ source/header files under root_path and extract all
    m.def(...) statements (multi-line calls are supported).

    Returns:
        list of ``{'file': str, 'm_def_statements': [str, ...]}`` dicts, one
        per file that contains at least one statement, in sorted path order.
    """
    results = []
    for file_path in _iter_source_files(root_path, _M_DEF_EXTENSIONS):
        content = _read_source_file(file_path)
        if content is None:
            continue
        m_def_list = _collect_m_def_statements(content)
        if m_def_list:
            results.append({
                'file': str(file_path),
                'm_def_statements': m_def_list
            })
    return results


def parse_m_def_statement(m_def_str):
    """Parse one raw ``m.def(...)`` statement.

    Args:
        m_def_str: Statement text containing ``m.def(``.

    Returns:
        dict with keys ``python_function_name`` (str), ``cpp_function_name``
        (str, or None for a lambda), ``is_lambda`` (bool), ``num_args`` (number
        of extra arguments after the callable, not counting bare string
        literals, which are docstrings) and ``default_args`` (``{index:
        default expression}`` where index counts those same extras).

    Raises:
        ValueError: if ``m.def(`` is missing, the call is not closed, fewer
            than two arguments are present, the first argument is not a string
            literal, or an argument has '=' with nothing after it.
    """
    result = {
        'python_function_name': None,
        'num_args': 0,
        'default_args': {},
        'is_lambda': False,
    }

    start = m_def_str.find('m.def(')
    if start == -1:
        raise ValueError(f'[{m_def_str}] Could not find m.def start position')

    content_start = start + len('m.def(')
    content_end = _find_closing_paren(m_def_str, content_start - 1)
    if content_end == -1:
        raise ValueError(f'[{m_def_str}] m.def parentheses not closed')

    # String-aware split: a docstring like "Add a, b" stays one argument.
    args_list = _split_top_level_commas(m_def_str[content_start:content_end])
    if len(args_list) < 2:
        raise ValueError(f'[{m_def_str}] m.def has insufficient arguments')

    # First argument: the Python-visible function name.
    str_match = re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"', args_list[0].strip())
    if not str_match:
        raise ValueError(f'[{m_def_str}] m.def first argument should be a string literal')
    result['python_function_name'] = str_match.group(1)

    # Second argument: '&ns::func', 'func', or a lambda starting with '['.
    cpp_func_part = args_list[1].strip()
    if cpp_func_part.startswith('&'):
        cpp_func_part = cpp_func_part[1:].strip()

    if cpp_func_part.startswith('['):
        result['is_lambda'] = True
        result['cpp_function_name'] = None
    else:
        cpp_func_name = cpp_func_part.split('::')[-1]
        ident = re.match(r'^([a-zA-Z_][a-zA-Z0-9_]*)', cpp_func_name)
        result['cpp_function_name'] = ident.group(1) if ident else cpp_func_name

    # Remaining extras: py::arg(...) entries, possibly mixed with a docstring.
    # Docstrings must not consume an index, otherwise every default value is
    # attributed to the wrong parameter. Other non-argument extras (call
    # guards, return value policies) are still counted.
    arg_index = 0
    for arg_expr in args_list[2:]:
        expr = arg_expr.strip()
        if _STRING_LITERAL_RE.match(expr):
            continue

        eq_pos = _find_top_level_equals(expr)
        if eq_pos != -1:
            default_val = expr[eq_pos + 1:].strip()
            if not default_val:
                raise ValueError(f'[{expr}] Default value is empty (arg {arg_index})')
            result['default_args'][arg_index] = default_val
        arg_index += 1

    result['num_args'] = arg_index
    return result


def extract_cpp_signature_from_content(cpp_func_name, content):
    """
    Search for the C++ function signature of cpp_func_name in the given file content.

    Returns:
        The text from the return type up to the matching ')' of the first
        candidate whose parameter list is closed, or None if cpp_func_name is
        empty or nothing matches.
    """
    if not cpp_func_name:
        return None

    # Line start, a return type (templates, namespaces, pointers allowed),
    # whitespace, the exact function name, then the opening '('.
    pattern = re.compile(
        r'^\s*'
        r'([\w:\s*<&>,\[\]\(\)]+?)'
        r'\s+'
        r'\b' + re.escape(cpp_func_name) + r'\b'
        r'\s*\(',
        re.MULTILINE
    )

    for match in pattern.finditer(content):
        # The pattern ends with '(', so match.end() - 1 is that paren.
        close_pos = _find_closing_paren(content, match.end() - 1)
        if close_pos != -1:
            return content[match.start(1):close_pos + 1].strip()

    return None


def parse_mdef_and_attach_cpp_signatures(item, func_index):
    """
    Enhance item by parsing m.def and extracting C++ function signature from global index

    Args:
        item: dict with an ``'m_def_statements'`` list of raw statements.
        func_index: ``{base_name: signature}`` mapping.

    Returns:
        ``{'m_def_statements': [{'raw': str, 'parsed': dict}, ...]}`` where each
        parsed dict gains a ``'cpp_signature'`` key (None when not found).

    Raises:
        ValueError: propagated from parse_m_def_statement.
    """
    statements_with_parsed_signatures = []

    for stmt in item['m_def_statements']:
        parsed = parse_m_def_statement(stmt)
        cpp_func_name = parsed.get('cpp_function_name')

        cpp_sig = None
        if cpp_func_name and cpp_func_name in func_index:
            cpp_sig = func_index[cpp_func_name]
        elif not parsed['is_lambda']:
            print(f'Warning: C++ function "{cpp_func_name}" not found in any .cpp file')

        parsed['cpp_signature'] = cpp_sig
        statements_with_parsed_signatures.append({
            'raw': stmt,
            'parsed': parsed
        })

    return {
        'm_def_statements': statements_with_parsed_signatures
    }


def parse_cpp_signature(cpp_sig):
    """
    Parse a C++ function signature and extract return type, parameter types, and names.

    Returns:
        ``{'return_type': str, 'parameters': [{'type', 'name'}, ...],
        'num_parameters': int}``, or None when the signature is empty, has no
        '(' or has fewer than two whitespace-separated tokens before it.
    """
    if not cpp_sig or not cpp_sig.strip():
        return None

    paren_pos = cpp_sig.find('(')
    if paren_pos == -1:
        return None

    # Everything before the first '(' is "<return type> <name>"; the last
    # whitespace-separated token is taken to be the function name.
    tokens = cpp_sig[:paren_pos].split()
    if len(tokens) < 2:
        return None
    return_type = ' '.join(tokens[:-1]).strip()

    param_list_str = cpp_sig[paren_pos + 1:cpp_sig.rfind(')')].strip()

    parameters = []
    if param_list_str and param_list_str != 'void':  # 'void' means no parameters
        for decl in split_cpp_parameters(param_list_str):
            decl = decl.strip()
            if not decl:
                continue
            param_info = parse_parameter_declaration(decl)
            if param_info:
                parameters.append(param_info)

    return {
        'return_type': return_type,
        'parameters': parameters,
        'num_parameters': len(parameters)
    }


def split_cpp_parameters(param_str: str):
    """
    Split a C++ parameter list string by top-level commas, e.g.,
    'int a, std::vector<int> b' → ['int a', 'std::vector<int> b']

    Empty parameters are dropped; '' and 'void' yield an empty list.
    """
    if not param_str.strip() or param_str == 'void':
        return []
    return [param for param in _split_top_level_commas(param_str) if param]


def parse_parameter_declaration(decl: str):
    """
    Parse a single parameter declaration, e.g.,
    'const std::string& name' → {'type': 'const std::string&', 'name': 'name'}

    A default value (everything from the first top-level '=') is discarded.
    The name is None when the declaration does not end in an identifier, or
    when that identifier is really part of the type: it is the whole
    declaration, a type keyword, preceded by '::', or preceded only by
    qualifiers such as 'const'.

    Returns:
        ``{'type': str, 'name': str or None}``, or None for an empty input.
    """
    decl = decl.strip()
    if not decl:
        return None

    eq_pos = _find_top_level_equals(decl)
    if eq_pos != -1:
        decl = decl[:eq_pos].strip()

    type_only = {'type': decl, 'name': None}

    trailing = _TRAILING_IDENTIFIER_RE.search(decl)
    if not trailing:
        # Ends with '>', '&', '*', ']' or ')': no recognisable name.
        return type_only

    candidate = trailing.group(1)
    # Pointer/reference markers stay on the type so that e.g. 'const char*'
    # can still be recognised as a C string later on.
    param_type = decl[:trailing.start()].rstrip()

    if (not param_type
            or candidate in _TYPE_KEYWORDS
            or param_type.endswith('::')
            or all(tok in _TYPE_ONLY_PREFIX_TOKENS for tok in param_type.split())):
        return type_only

    return {'type': param_type, 'name': candidate}


def extract_cpp_signature_details(item):
    """
    For each m.def entry in item, parse cpp_signature to extract return type and parameter details.

    Returns:
        ``{'m_def_statements': [...]}`` with the same entries, each parsed dict
        gaining ``'cpp_parsed_signature'`` (None when there is no signature or
        parsing it failed).
    """
    statements_with_parsed_signatures = []

    for stmt_info in item['m_def_statements']:
        parsed = stmt_info['parsed']
        cpp_sig = parsed.get('cpp_signature')

        cpp_params_info = None
        if cpp_sig:
            try:
                cpp_params_info = parse_cpp_signature(cpp_sig)
            except Exception as e:
                # Best effort: one bad signature must not abort the whole run.
                print(f'Failed to parse C++ signature: {e}')

        parsed['cpp_parsed_signature'] = cpp_params_info
        statements_with_parsed_signatures.append({
            'raw': stmt_info['raw'],
            'parsed': parsed
        })

    return {
        'm_def_statements': statements_with_parsed_signatures
    }


def _template_args(cleaned, prefix):
    """Return the top-level template arguments of ``prefix...>`` in *cleaned*."""
    return split_template_args(cleaned[len(prefix):-1].strip())


def cpp_type_to_python_type(cpp_type: str) -> str:
    """Map a C++ type spelling to a Python annotation string.

    Handles void, std::pair/tuple/vector/optional (recursively), std::string,
    C strings, bool, floating-point and integer types and torch::Tensor.
    Anything else prints a warning and yields 'Any'.
    """
    if not cpp_type:
        return 'Any'

    original = cpp_type.strip()
    if not original:
        return 'Any'

    # Remove C++ specifiers that don't affect Python type
    cleaned = re.sub(
        r'\b(static|inline|constexpr|thread_local|extern|mutable|const|volatile|endif)\b',
        '',
        original,
    )
    cleaned = cleaned.replace('&', '').replace('*', '').strip()
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()

    if cleaned == 'void':
        return 'None'

    # Template types must be checked before the scalar checks below, because
    # those search anywhere in the string (e.g. 'int' inside a vector).
    if cleaned.startswith('std::pair<'):
        args = _template_args(cleaned, 'std::pair<')
        if len(args) == 2:
            t1 = cpp_type_to_python_type(args[0])
            t2 = cpp_type_to_python_type(args[1])
            return f'tuple[{t1}, {t2}]'
        print(f'Warning: std::pair with unexpected number of args: {cleaned}')
        return 'Any'

    if cleaned.startswith('std::tuple<'):
        py_types = [cpp_type_to_python_type(arg)
                    for arg in _template_args(cleaned, 'std::tuple<')]
        return f"tuple[{', '.join(py_types)}]"

    if cleaned.startswith('std::vector<'):
        args = _template_args(cleaned, 'std::vector<')
        if len(args) == 1:
            return f'list[{cpp_type_to_python_type(args[0])}]'
        print(f'Warning: std::vector with unexpected args: {cleaned}')
        return 'Any'

    if cleaned.startswith('std::optional<'):
        args = _template_args(cleaned, 'std::optional<')
        if len(args) == 1:
            return f'Optional[{cpp_type_to_python_type(args[0])}]'
        print(f'Warning: std::optional with unexpected args: {cleaned}')
        return 'Any'

    if re.search(r'\bstd::string\b', original):
        return 'str'

    # C-style strings: char*, const char*, char[], etc.
    if re.search(r'\b(?:const\s+)?char\s*[\*\[]', original):
        return 'str'

    if re.search(r'\bbool\b', cleaned):
        return 'bool'

    # Floating-point is tested before integers: 'long double' contains the
    # integer keyword 'long' and would otherwise be reported as int.
    if re.search(r'\b(float|double|long\s+double)\b', cleaned):
        return 'float'

    # Integer types (including fixed-width and common aliases)
    if re.search(r'\b(int|long|short|size_t|ssize_t|ptrdiff_t|'
                 r'int8_t|int16_t|int32_t|int64_t|'
                 r'uint8_t|uint16_t|uint32_t|uint64_t)\b', cleaned):
        return 'int'

    if re.search(r'\btorch::Tensor\b', original):
        return 'torch.Tensor'

    print(f'Warning: Unrecognized C++ type: {original}')
    return 'Any'


def split_template_args(template_args: str):
    """
    Split template arguments, e.g., 'int, std::vector<int>' → ['int', 'std::vector<int>']
    """
    if not template_args.strip():
        return []
    return _split_top_level_commas(template_args)


def cpp_default_to_python_default(cpp_default: str):
    """
    Convert C++ default value string to valid Python expression string.

    Recognises string literals, true/false, nullptr/NULL/nullopt,
    std::tuple<...>({...}), std::make_tuple(...), std::vector<...>({...}) and
    integer/floating literals (C++ suffixes such as 'f' or 'L' are removed).
    Anything else prints a warning and yields 'None'.
    """
    if not cpp_default:
        return 'None'

    s = cpp_default.strip()

    # A complete double-quoted literal is also a valid Python literal.
    if re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"$', s):
        return s

    if s == 'false':
        return 'False'
    if s == 'true':
        return 'True'

    # Handle null-like values: nullptr, nullopt, NULL, etc.
    if s in ('nullptr', 'NULL') or 'nullopt' in s:
        return 'None'

    # std::tuple<int, int>({128, 128}) → (128, 128)
    tuple_match = re.match(r'std::tuple\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s)
    if tuple_match:
        return tuple_match.group(1).replace('{', '(').replace('}', ')')

    # std::make_tuple(1, 2, 3) → (1, 2, 3)
    make_tuple_match = re.match(r'std::make_tuple\s*\(\s*(.*?)\s*\)', s)
    if make_tuple_match:
        inner = make_tuple_match.group(1)
        elements = [piece for piece in _split_top_level_commas(inner) if piece]
        if len(elements) == 1:
            # '(1)' is just a parenthesised int in Python; a one-element
            # tuple needs the trailing comma.
            return f'({inner},)'
        return f'({inner})'

    # std::vector<int>({1,2,3}) → [1,2,3]
    vector_match = re.match(r'std::vector\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s)
    if vector_match:
        return vector_match.group(1).replace('{', '[').replace('}', ']')

    int_match = _INT_LITERAL_RE.match(s)
    if int_match:
        return int_match.group(1)
    float_match = _FLOAT_LITERAL_RE.match(s)
    if float_match:
        return float_match.group(1)

    print(f'Warning: Unrecognized default value: {s}')
    return 'None'


def generate_pyi_function(item_entry):
    """Render one stub declaration for a parsed m.def entry.

    Args:
        item_entry: dict with a ``'parsed'`` dict as produced by
            extract_cpp_signature_details.

    Returns:
        A ``def name(...) -> type: ...`` string. Lambdas and entries without a
        parsed C++ signature get ``(*args, **kwargs) -> Any``.
    """
    parsed = item_entry['parsed']
    py_name = parsed['python_function_name']
    untyped_stub = f'def {py_name}(*args, **kwargs) -> Any: ...'

    if parsed.get('is_lambda'):
        return untyped_stub

    sig_info = parsed.get('cpp_parsed_signature')
    if not sig_info:
        return untyped_stub

    default_args = parsed.get('default_args', {})
    return_type = cpp_type_to_python_type(sig_info['return_type'])

    param_lines = []
    for i, param_info in enumerate(sig_info['parameters']):
        param_type = cpp_type_to_python_type(param_info['type'])
        param_name = param_info['name'] or f'arg{i}'

        # A C++ name that is a Python keyword (lambda, in, from, ...) would
        # make the stub a syntax error.
        if keyword.iskeyword(param_name):
            param_name = f'{param_name}_'

        if i in default_args:
            py_default = cpp_default_to_python_default(default_args[i])
            param_lines.append(f' {param_name}: {param_type} = {py_default}')
        else:
            param_lines.append(f' {param_name}: {param_type}')

    if not param_lines:
        return f'def {py_name}() -> {return_type}: ...'

    params_block = ',\n'.join(param_lines)
    return f'def {py_name}(\n{params_block}\n) -> {return_type}: ...'


def generate_pyi_file_content(enhanced_results, module_name: str = 'my_module'):
    """Build the full text of a stub file.

    Args:
        enhanced_results: iterable of dicts with an ``'m_def_statements'`` list
            of ``{'raw', 'parsed'}`` entries.
        module_name: Name written into the header comment.

    Returns:
        The stub source: header comment, the imports the declarations need,
        then one declaration per entry. An entry whose rendering raises is
        replaced by a ``# ERROR`` comment line.
    """
    function_decls = []
    has_optional = False
    has_torch = False
    has_numpy = False

    for item in enhanced_results:
        for stmt in item['m_def_statements']:
            try:
                decl = generate_pyi_function(stmt)
                function_decls.append(decl)

                if 'Optional[' in decl:
                    has_optional = True
                if 'torch.Tensor' in decl:
                    has_torch = True
                if 'numpy.ndarray' in decl or 'py::array' in str(stmt):
                    has_numpy = True
            except Exception as e:
                # Keep going: record the failure in the stub itself.
                func_name = stmt['parsed'].get('python_function_name', 'unknown')
                function_decls.append(
                    f'# ERROR: failed to generate stub for {func_name}: {e}'
                )

    imports = ['from typing import Any']
    if has_optional:
        imports[0] += ', Optional'
    if has_torch:
        imports.append('import torch')
    if has_numpy:
        imports.append('import numpy')

    lines = [f'# Stubs for module: {module_name}', '']
    lines.extend(imports)
    lines.append('')
    lines.append('')

    for decl in function_decls:
        lines.append(decl)
        lines.append('')
        lines.append('')

    return '\n'.join(lines)


def generate_pyi_file(name, root, output_dir='.'):
    """Scan *root* and write ``<output_dir>/<name>.pyi``.

    Args:
        name: Module name; also the stem of the output file.
        root: Root directory of the C++ sources.
        output_dir: Directory for the stub file; created if missing.

    Raises:
        ValueError: if an m.def statement cannot be parsed.
    """
    func_index = build_cpp_function_index(root)
    results = extract_m_def_statements(root)

    cpp_results = []
    for item in results:
        enhanced_item = parse_mdef_and_attach_cpp_signatures(item, func_index)
        cpp_results.append(extract_cpp_signature_details(enhanced_item))

    pyi_content = generate_pyi_file_content(cpp_results, module_name=name)

    output_path = Path(output_dir) / f'{name}.pyi'
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(pyi_content)

    print(f'.pyi file generated: {output_path}')