# -*- coding: utf-8 -*-
#
# Copyright Grégory Soutadé
#
# This file is part of SOAdvancedDissector
#
# SOAdvancedDissector is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SOAdvancedDissector is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SOAdvancedDissector. If not, see <http://www.gnu.org/licenses/>.
#

"""Object model (namespaces, classes, functions, attributes) whose __str__
methods render C++-style declarations."""

import sys
from cppprototypeparser import CPPPrototypeParser

# Rendering options, changed through the setters below
print_raw_virtual_table = False  # print whole vtables, not only overloaded methods
print_indent = False             # indent rendered declarations
cur_indent = 0                   # current nesting level, updated while rendering
parser = CPPPrototypeParser()    # shared prototype parser (see Function.getDependencies)


def setPrintRawVirtualTable(value):
    """Print full virtual tables (True) or only methods overloaded by the class (False)"""
    global print_raw_virtual_table
    print_raw_virtual_table = value


def setPrintIndent(value):
    """Enable/disable indentation of rendered declarations"""
    global print_indent
    print_indent = value


def _getIndent():
    """Return the indentation prefix for the current nesting level ('' when disabled)"""
    if print_indent:
        return ' ' * cur_indent
    return ''


class Object:
    """Abstract object representation (simply a name)

    NOTE: defining __eq__ without __hash__ makes instances unhashable on
    Python 3; containers of objects are plain lists.
    """

    def __init__(self, name):
        self.name = name

    def find(self, obj):
        """Return self if it matches obj (see __eq__), None otherwise"""
        if self == obj:
            return self
        return None

    def getParametersDependencies(self):
        return None

    def getDependencies(self):
        """Get dependencies from other namespaces (none for a plain object)"""
        return None

    def __eq__(self, other):
        """Compare by name; other may be a plain string or another object"""
        if isinstance(other, str):
            return self.name == other
        return self.name == other.name

    def __lt__(self, other):
        """Order by name (used by sorted() when rendering)"""
        return self.name < other.name


class Attribute(Object):
    """Class attribute (member)"""

    def __init__(self, name, address=0, namespace=''):
        Object.__init__(self, name)
        self.address = address
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name (name alone when there is no namespace)"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        return self.name

    def __eq__(self, other):
        """Compare with a string by name, with an object having an address
        by address, and with any other object by name.

        NOTE: two attributes with the same address (including the default 0)
        are equal whatever their names.
        """
        if isinstance(other, str):
            return self.name == other
        if hasattr(other, 'address'):
            return self.address == other.address
        return self.name == other.name

    def __str__(self):
        return '{}void* {};\n'.format(_getIndent(), self.name)


class Function(Attribute):
    """Function description"""

    def __init__(self, name, address=0, virtual=False, pure_virtual=False, namespace=''):
        # Be sure we have () in name: close a truncated parameter list
        if '(' in name and not name.endswith((')', 'const', 'const&')):
            name += ')'
        Attribute.__init__(self, name, address, namespace)
        self.virtual = virtual
        self.pure_virtual = pure_virtual
        self.constructor = False  # constructors/destructors are printed without return type

    def isPure(self):
        """Is function pure virtual"""
        return self.pure_virtual

    def setConstructor(self, value):
        """Set/clear constructor property"""
        self.constructor = value

    def getDependencies(self):
        """Get dependencies from other namespaces

        The prototype is parsed with the shared module parser. Every
        qualified parameter whose first component is not a prefix of our
        namespace is reported, stripped of '*', '&' and of a trailing ' const'.

        Returns
        -------
        list
            Sorted unique qualified names, or None if there are none
        """
        parser.parse(self.name)
        if not parser.has_parameters:
            return None
        dependencies = set()
        for p in parser.parameters:
            # Qualified parameters presumably come as a list of name
            # components (p[0] = outermost scope) -- TODO confirm against
            # CPPPrototypeParser
            if type(p) != list or len(p) < 2:
                continue
            if self.namespace.startswith(p[0]):
                continue
            dep = '::'.join(p).replace('*', '').replace('&', '')
            if dep.endswith(' const'):
                dep = dep[:-len(' const')]
            dependencies.add(dep)
        # Sorted to get a reproducible output
        return sorted(dependencies) if dependencies else None

    def _getReturnType(self):
        """Get method return type (with trailing space), '' for constructors"""
        if self.name.startswith(('vtable_index', 'operator new')):
            return 'void* '
        if self.constructor:
            return ''
        return 'void '

    def __str__(self):
        type_ = self._getReturnType()
        if self.pure_virtual:
            decl = 'virtual {}{} = 0;\n'.format(type_, self.name)
        elif self.virtual:
            decl = 'virtual {}{};\n'.format(type_, self.name)
        else:
            decl = '{}{};\n'.format(type_, self.name)
        return '{}{}'.format(_getIndent(), decl)


def _sortClassesByInheritance(classes):
    """Return classes ordered so that base classes come before derived ones

    Classes are visited in name order; before a class is emitted, every
    class of the list it inherits from is emitted first (depth-first).
    Bases that are not part of classes are ignored. Each class is emitted
    exactly once, so an (invalid) inheritance cycle cannot loop forever.
    """
    by_name = sorted(classes)
    ordered = []
    visited = set()  # id() of visited classes: objects themselves are not hashable

    def visit(class_):
        if id(class_) in visited:
            return
        visited.add(id(class_))
        for base in by_name:
            if base is not class_ and base in class_.inherit_from:
                visit(base)
        ordered.append(class_)

    for class_ in by_name:
        visit(class_)
    return ordered


class Namespace(Object):
    """Namespace description"""

    def __init__(self, name):
        Object.__init__(self, name)
        self.childs = []
        self.dependencies = []  # Dependencies from objects in other namespace

    def addChild(self, child):
        """Add child (function, class, attribute) to namespace, unless an
        equal one (see __eq__) is already present"""
        if child not in self.childs:
            self.childs.append(child)

    def removeChild(self, child):
        """Remove child from namespace (ValueError if absent)"""
        self.childs.remove(child)

    def child(self, name):
        """Try to find name in childs without recursion"""
        for child in self.childs:
            if child.name == name:
                return child
        return None

    def find(self, obj):
        """Try to find obj in childs and their own child (with recursion)"""
        if self == obj:
            return self
        for child in self.childs:
            res = child.find(obj)
            if res:
                return res
        return None

    def fillFrom(self, other):
        """Copy all childs from other object (no duplicate check)"""
        # extend() also behaves when other is self (the former append loop never ended)
        self.childs.extend(other.childs)

    def getDependencies(self):
        """Get dependencies from other namespaces

        Returns
        -------
        list
            Sorted unique names reported by our childs that are not inside
            this namespace (possibly empty)
        """
        prefix = '{}::'.format(self.name)
        dependencies = set()
        for child in self.childs:
            for d in child.getDependencies() or []:
                if not d.startswith(prefix):
                    dependencies.add(d)
        return sorted(dependencies)

    def __str__(self):
        global cur_indent
        if self.name != 'global':
            res = '{}namespace {} {{\n\n'.format(_getIndent(), self.name)
            cur_indent += 1
        else:
            res = ''

        # Split childs by exact type (Class is a subclass of Namespace)
        namespaces = []
        classes = []
        functions = []
        other = []
        for child in self.childs:
            if type(child) == Namespace:
                namespaces.append(child)
            elif type(child) == Class:
                classes.append(child)
            elif type(child) == Function:
                functions.append(child)
            else:
                other.append(child)

        # Base classes must be defined before classes deriving from them
        classes_dep = _sortClassesByInheritance(classes)

        # Add class (forward) declaration
        if len(classes_dep) > 1:
            for c in classes_dep:
                res += '{}class {};\n'.format(_getIndent(), c.name)
        if classes_dep:
            res += '\n\n'

        for namespace in sorted(namespaces):
            res += namespace.__str__()
        if namespaces:
            res += '\n'

        for c in classes_dep:
            res += c.__str__()
        if classes_dep:
            res += '\n'

        for func in sorted(functions):
            res += func.__str__()
        if functions:
            res += '\n'

        for child in sorted(other):
            res += child.__str__()
        if other:
            res += '\n'

        if self.name != 'global':
            cur_indent -= 1
            res += '{}}}\n'.format(_getIndent())
            res += '\n'
        return res


class Class(Namespace):
    """Class description"""

    def __init__(self, name, namespace=''):
        Namespace.__init__(self, name)
        self.constructors = []
        self.destructors = []
        self.inherit_from = []
        self.virtual_functions = []
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name (name alone when there is no namespace)"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        return self.name

    def _checkConstructor(self, obj):
        """Check if obj is a constructor/destructor and set its property.

        Returns
        -------
        list
            Matching list (constructors or destructors) or None
        """
        if type(obj) != Function:
            return None
        if obj.name.startswith('~'):
            obj.setConstructor(True)
            return self.destructors
        if obj.name.startswith('{}('.format(self.name)):
            obj.setConstructor(True)
            return self.constructors
        return None

    def addVirtualFunction(self, obj):
        """Add a new virtual function

        Entries with address 0 are always appended: they would all compare
        equal to each other (see Attribute.__eq__).
        """
        if obj.address == 0 or obj not in self.virtual_functions:
            self._checkConstructor(obj)
            self.virtual_functions.append(obj)

    def updateVirtualFunction(self, idx, obj):
        """Update virtual function at index idx (error reported on stderr)"""
        try:
            self._checkConstructor(obj)
            self.virtual_functions[idx] = obj
        except IndexError:
            sys.stderr.write('updateVirtualFunction Error, cur vtable size {}; idx {}, class {}, obj {}\n'.format(len(self.virtual_functions), idx, self.name, obj))
            sys.stderr.flush()

    def addMember(self, obj):
        """Add a new member, unless an equal object is already known

        Constructors/destructors go to their own lists, everything else to childs.
        """
        if obj in self.virtual_functions or \
           obj in self.constructors or \
           obj in self.destructors or \
           obj in self.childs:
            return
        targetList = self._checkConstructor(obj)
        if targetList is None:
            self.childs.append(obj)
        else:
            targetList.append(obj)

    def addChild(self, child):
        return self.addMember(child)

    def addBaseClass(self, obj):
        self.inherit_from.append(obj)

    def fixVirtualFunction(self, index, newName):
        """Set real name for virtfunc and unknown virtfunc generic names

        Only pure virtual entries named virtfunc* or unknown_virtfunc* are
        renamed; the entry is then re-checked for constructor/destructor.

        Returns
        -------
        bool
            True if the entry at index has been renamed
        """
        if index >= len(self.virtual_functions):
            sys.stderr.write('FVF Error {} > {} for {}/{}\n'.format(index, len(self.virtual_functions), newName, self.fullname()))
            return False
        virtfunc = self.virtual_functions[index]
        if not virtfunc.isPure():
            return False
        if virtfunc.name.startswith(('virtfunc', 'unknown_virtfunc')):
            virtfunc.name = newName
            self._checkConstructor(virtfunc)
            return True
        return False

    def hasMultipleVirtualSections(self):
        """Check if we have a vtable_indexX (X>0) entry in our virtual functions table
        """
        return any(vfunc.name.startswith('vtable_index') and
                   vfunc.name != 'vtable_index0'
                   for vfunc in self.virtual_functions)

    def fixupInheritance(self):
        """Report virtual function names into base classes whose slots are
        still generic pure virtual placeholders.

        Bases are fixed up first. Then our virtual table is walked slot by
        slot against each base table, calling fixVirtualFunction() on the
        base. A base that got renamed entries is fixed up again so that the
        names propagate further up the hierarchy.
        """
        if not self.inherit_from:
            return

        # First, handle all of our bases
        for base in self.inherit_from:
            base.fixupInheritance()

        nbEntries = len(self.virtual_functions)
        curIdx = 0
        if self.hasMultipleVirtualSections():
            # One section per base, each one starting with a vtable_index entry
            for base in self.inherit_from:
                # First is vtable_index, skip it
                curIdx += 1
                targetIdx = 1
                updated = False
                while targetIdx < len(base.virtual_functions) and curIdx < nbEntries:
                    vfunc = self.virtual_functions[curIdx]
                    if vfunc.name.startswith('vtable_index'):
                        # Our section is shorter than the base table
                        break
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()
                # Skip our own extra entries up to the next vtable_index entry
                while curIdx < nbEntries and \
                      not self.virtual_functions[curIdx].name.startswith('vtable_index'):
                    curIdx += 1
        else:
            for base in self.inherit_from:
                targetIdx = 0
                updated = False
                while targetIdx < len(base.virtual_functions) and curIdx < nbEntries:
                    vfunc = self.virtual_functions[curIdx]
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()

    def looksLikeNamespace(self):
        """Empty specific class attributes looks like a namespace
        """
        return not self.constructors and \
            not self.destructors and \
            not self.inherit_from and \
            not self.virtual_functions

    def _getDependencies(self, targetList, dependencies):
        """Append to dependencies the names reported by each object of
        targetList that are outside our namespace.

        NOTE(review): nothing is collected when the class has no namespace,
        whereas Function.getDependencies() does report them -- confirm intent.
        """
        prefix = '{}::'.format(self.namespace)
        for obj in targetList:
            res = obj.getDependencies()
            if not res or not self.namespace:
                continue
            for d in res:
                if d != self.namespace and not d.startswith(prefix):
                    dependencies.append(d)

    def getDependencies(self):
        """Get dependencies from other namespaces

        Returns
        -------
        list
            Sorted unique names (namespaced bases and dependencies of our
            members), or None if there are none
        """
        dependencies = []
        for base in self.inherit_from:
            if base.namespace:
                dependencies.append(base.fullname())
        self._getDependencies(self.constructors, dependencies)
        self._getDependencies(self.destructors, dependencies)
        self._getDependencies(self.virtual_functions, dependencies)
        self._getDependencies(self.childs, dependencies)
        if dependencies:
            return sorted(set(dependencies))
        return None

    def printOverloadedVirtualTable(self):
        """Only select overloaded methods from virtual table

        vtable_index and typeinfo() entries are skipped, as are entries from
        another namespace (not overloaded by us). A declaration already
        emitted is printed only once.
        """
        lines = []
        seen = set()
        for vfunc in self.virtual_functions:
            if vfunc.name.startswith(('vtable_index', 'typeinfo()')):
                continue
            # Not overloaded by us
            if vfunc.namespace and vfunc.namespace != self.namespace:
                continue
            line = vfunc.__str__()
            # Method already in result (it's the case for virtual destructors).
            # Exact match: the former substring search could hide other lines.
            if line in seen:
                continue
            seen.add(line)
            lines.append(line)
        return ''.join(lines)

    def __str__(self):
        global cur_indent
        res = '{}class {}'.format(_getIndent(), self.name)
        if self.inherit_from:
            res += ': '
            res += ', '.join('public {}'.format(base.fullname())
                             for base in self.inherit_from)
        res += '\n{}{{\n'.format(_getIndent())
        res += '{}public:\n'.format(_getIndent())
        cur_indent += 1

        for constructor in sorted(self.constructors):
            res += constructor.__str__()
        if self.constructors:
            res += '\n'

        for destructor in sorted(self.destructors):
            res += destructor.__str__()
        if self.destructors:
            res += '\n'

        # Do not sort virtual tables !
        if print_raw_virtual_table:
            virtfuncs = ''.join(virtfunc.__str__() for virtfunc in self.virtual_functions)
        else:
            virtfuncs = self.printOverloadedVirtualTable()
        if virtfuncs:
            res += '{}\n'.format(virtfuncs)

        methods = []
        other = []
        for child in self.childs:
            if type(child) == Function:
                methods.append(child)
            else:
                other.append(child)

        for method in sorted(methods):
            res += method.__str__()
        if methods:
            res += '\n'

        for child in sorted(other):
            res += child.__str__()
        if other:
            res += '\n'

        cur_indent -= 1
        res += '{}}};\n\n'.format(_getIndent())
        return res