537 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			537 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # -*- coding: utf-8 -*-
 | |
| #
 | |
| # Copyright Grégory Soutadé
 | |
| 
 | |
| # This file is part of SOAdvancedDissector
 | |
| 
 | |
| # SOAdvancedDissector is free software: you can redistribute it and/or modify
 | |
| # it under the terms of the GNU General Public License as published by
 | |
| # the Free Software Foundation, either version 3 of the License, or
 | |
| # (at your option) any later version.
 | |
| #
 | |
| # SOAdvancedDissector is distributed in the hope that it will be useful,
 | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
| # GNU General Public License for more details.
 | |
| #
 | |
| # You should have received a copy of the GNU General Public License
 | |
| # along with SOAdvancedDissector.  If not, see <http://www.gnu.org/licenses/>.
 | |
| #
 | |
| 
 | |
| import sys
 | |
| from cppprototypeparser import CPPPrototypeParser
 | |
| 
 | |
# Printing options, changed through setPrintRawVirtualTable()/setPrintIndent()
print_raw_virtual_table = print_indent = False
# Current nesting level read by _getIndent(); namespaces and classes
# increment/decrement it while they are being printed
cur_indent = 0

# Shared C++ prototype parser (Function.getDependencies parses names with it)
parser = CPPPrototypeParser()
 | |
| 
 | |
def setPrintRawVirtualTable(value):
    """Select how class virtual tables are printed.

    value -- truthy to print every virtual table entry as is, falsy to
             print the filtered view built by
             Class.printOverloadedVirtualTable()
    """
    global print_raw_virtual_table
    print_raw_virtual_table = value
 | |
| 
 | |
def setPrintIndent(value):
    """Enable (truthy value) or disable indentation of printed declarations.

    When enabled, _getIndent() returns four spaces per nesting level.
    """
    global print_indent
    print_indent = value
 | |
| 
 | |
def _getIndent():
    """Return the leading whitespace for the current nesting level.

    Four spaces per level of cur_indent when indentation is enabled
    (see setPrintIndent), otherwise an empty string.
    """
    if not print_indent:
        return ''
    return '    ' * cur_indent
 | |
|     
 | |
class Object:
    """Abstract object representation (simply a name)"""
    def __init__(self, name):
        self.name = name

    def find(self, obj):
        """Return self if it equals obj (an object or a plain name), else None"""
        if self == obj:
            return self
        return None

    def getParametersDependencies(self):
        return None

    def getDependencies(self):
        """Get dependencies from other namespaces (none for a plain object)"""
        return None

    def __eq__(self, other):
        """Compare by name.

        other may be a str (compared with our name) or any object having
        a `name` attribute. For any other value NotImplemented is returned,
        so that e.g. `obj == None` evaluates to False instead of raising
        AttributeError.
        """
        if isinstance(other, str):
            return self.name == other
        if not hasattr(other, 'name'):
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other):
        """Order objects by name (used by sorted() when printing)"""
        return self.name < other.name
 | |
| 
 | |
class Attribute(Object):
    """Class attribute (member): a name, an address and an owning namespace"""
    def __init__(self, name, address=0, namespace=''):
        Object.__init__(self, name)
        self.address = address
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name (just name when there is no namespace)"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        else:
            return self.name

    def __eq__(self, other):
        """Compare with a name (str), by address with objects having an
        address, by name with other named objects.

        Returns NotImplemented for unrelated values so that e.g.
        `attr == None` evaluates to False instead of raising AttributeError.
        """
        if isinstance(other, str):
            return self.name == other
        if hasattr(other, 'address'):
            # NOTE(review): address defaults to 0, so two distinct symbols
            # both left at address 0 compare equal here (Class.addVirtualFunction
            # special-cases address 0 for that reason) -- confirm intended
            # for the other membership tests.
            return self.address == other.address
        if hasattr(other, 'name'):
            return self.name == other.name
        return NotImplemented

    def __str__(self):
        return '{}void* {};\n'.format(_getIndent(), self.name)
 | |
| 
 | |
| 
 | |
class Function(Attribute):
    """Function (or method) description.

    name holds the prototype, e.g. "foo(int)". When it contains '(' but
    does not end with ')', 'const' or 'const&', a closing ')' is appended.
    """
    def __init__(self, name, address=0, virtual=False, pure_virtual=False, namespace=''):
        # Close the parameter list of a truncated prototype
        if '(' in name and not name.endswith((')', 'const', 'const&')):
            name += ')'
        Attribute.__init__(self, name, address, namespace)
        self.virtual = virtual
        self.pure_virtual = pure_virtual
        # Constructors/destructors are printed without a return type
        self.constructor = False

    def isPure(self):
        """Tell whether the function is pure virtual"""
        return self.pure_virtual

    def setConstructor(self, value):
        """Mark (truthy value) or unmark the function as constructor/destructor"""
        self.constructor = value

    def getDependencies(self):
        """Return the qualified parameter types coming from other namespaces.

        The prototype is parsed with the shared parser; each qualified
        parameter (a list of several components) whose first component is
        not a prefix of our namespace is joined with '::', stripped of
        '*', '&' and a trailing ' const'. Returns a list without
        duplicates, or None when nothing qualifies.
        """
        parser.parse(self.name)
        if not parser.has_parameters:
            return None

        deps = []
        for param in parser.parameters:
            if type(param) is not list or len(param) <= 1:
                continue
            if self.namespace.startswith(param[0]):
                continue
            qualified = '::'.join(param)
            if '::' not in qualified:
                continue
            qualified = qualified.replace('*', '').replace('&', '')
            if qualified.endswith(' const'):
                qualified = qualified[:-len(' const')]
            deps.append(qualified)

        return list(set(deps)) if deps else None

    def _getReturnType(self):
        """Return the return-type prefix used when printing (may be empty)"""
        if self.name.startswith(('vtable_index', 'operator new')):
            return 'void* '
        if self.constructor:
            return ''
        return 'void '

    def __str__(self):
        is_virtual = self.pure_virtual or self.virtual
        prefix = 'virtual ' if is_virtual else ''
        suffix = ' = 0' if self.pure_virtual else ''
        return '{}{}{}{}{};\n'.format(_getIndent(), prefix,
                                       self._getReturnType(), self.name,
                                       suffix)
 | |
| 
 | |
class Namespace(Object):
    """Namespace description: a named container of children
    (namespaces, classes, functions, attributes)"""
    def __init__(self, name):
        Object.__init__(self, name)
        self.childs = []
        self.dependencies = [] # Dependencies from objects in other namespace

    def addChild(self, child):
        """Add child (function, class, attribute) to namespace, unless an
        equal object is already present"""
        if child not in self.childs:
            self.childs.append(child)

    def removeChild(self, child):
        """Remove child from namespace (list.remove: ValueError if absent)"""
        self.childs.remove(child)

    def child(self, name):
        """Try to find name in childs without recursion"""
        for child in self.childs:
            if child.name == name:
                return child
        return None

    def find(self, obj):
        """Try to find obj in childs and their own child (with recursion)"""
        if self == obj:
            return self

        for child in self.childs:
            res = child.find(obj)
            if res:
                return res
        return None

    def fillFrom(self, other):
        """Copy all childs from other object (no duplicate check)"""
        for child in other.childs:
            self.childs.append(child)

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Returns the children's dependencies that do not start with
        "<name>::", without duplicates (an empty list when there are none).
        """
        dependencies = []

        for child in self.childs:
            depend = child.getDependencies()
            if depend:
                for d in depend:
                    if not d.startswith('{}::'.format(self.name)):
                        dependencies.append(d)

        if dependencies:
            return list(set(dependencies))
        return []

    @staticmethod
    def _sortClassesByInheritance(classes):
        """Return classes sorted by name, except that every class comes
        after the classes of the list it inherits from.

        A class is considered a base of cls when `base in cls.inherit_from`
        (i.e. matched through __eq__). Inheritance cycles are cut instead
        of recursing forever.
        """
        candidates = sorted(classes)
        ordered = []
        in_progress = []

        def contains(seq, cls):
            # Identity test: __eq__ only compares names
            return any(c is cls for c in seq)

        def visit(cls):
            if contains(ordered, cls) or contains(in_progress, cls):
                return
            in_progress.append(cls)
            # Place our bases first, in name order
            for other in candidates:
                if other is not cls and other in cls.inherit_from:
                    visit(other)
            in_progress.pop()
            ordered.append(cls)

        for cls in candidates:
            visit(cls)
        return ordered

    def __str__(self):
        global cur_indent
        if self.name != 'global':
            res = '{}namespace {} {{\n\n'.format(_getIndent(), self.name)
            cur_indent += 1
        else:
            res = ''

        # Exact type checks on purpose: Class derives from Namespace
        namespaces = []
        classes = []
        functions = []
        other = []
        for child in self.childs:
            if type(child) == Namespace: namespaces.append(child)
            elif type(child) == Class: classes.append(child)
            elif type(child) == Function: functions.append(child)
            else: other.append(child)

        # Base classes must be printed before the classes deriving from them
        classes_dep = self._sortClassesByInheritance(classes)

        # Forward class declarations (only emitted for several classes)
        if len(classes_dep) > 1:
            for c in classes_dep:
                res += '{}class {};\n'.format(_getIndent(), c.name)
            res += '\n\n'

        for namespace in sorted(namespaces):
            res += namespace.__str__()
        if namespaces: res += '\n'

        for c in classes_dep:
            res += c.__str__()
        if classes_dep: res += '\n'

        for func in sorted(functions):
            res += func.__str__()
        if functions: res += '\n'

        for child in sorted(other):
            res += child.__str__()
        if other: res += '\n'

        if self.name != 'global':
            cur_indent -= 1
            res += '{}}}\n'.format(_getIndent())
        res += '\n'

        return res
 | |
|         
 | |
class Class(Namespace):
    """Class description"""
    def __init__(self, name, namespace=''):
        Namespace.__init__(self, name)

        self.constructors = []
        self.destructors = []
        self.inherit_from = []
        self.virtual_functions = []
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        else:
            return self.name

    def _checkConstructor(self, obj):
        """Check if obj is a constructor/destructor and
        set its property.

        Destructors are flagged with setConstructor(True) too, so that both
        are printed without a return type.

        Returns
        -------
        list
            Matching list (constructors or destructors) or None
        """
        if type(obj) != Function: return None
        if obj.name.startswith('~'):
            obj.setConstructor(True)
            return self.destructors
        if obj.name.startswith('{}('.format(self.name)):
            obj.setConstructor(True)
            return self.constructors
        return None

    def addVirtualFunction(self, obj):
        """Add a new virtual function (entries at address 0 are always added)"""
        if obj.address == 0 or not obj in self.virtual_functions:
            self._checkConstructor(obj)
            self.virtual_functions.append(obj)

    def updateVirtualFunction(self, idx, obj):
        """Replace the virtual function at index idx by obj.

        Failures (e.g. idx out of range) are reported on stderr and
        otherwise ignored.
        """
        try:
            self._checkConstructor(obj)
            self.virtual_functions[idx] = obj
        except Exception:
            sys.stderr.write('updateVirtualFunction Error, cur vtable size {}; idx {}, class {}, obj {}\n'.format(len(self.virtual_functions), idx, self.name, obj))
            sys.stderr.flush()

    def addMember(self, obj):
        """Add a new member; constructors/destructors go to their own list"""
        if obj in self.virtual_functions or\
           obj in self.constructors or\
           obj in self.destructors or\
           obj in self.childs:
            return

        targetList = self._checkConstructor(obj)
        if targetList is None:
            self.childs.append(obj)
        else:
            targetList.append(obj)

    def addChild(self, child):
        return self.addMember(child)

    def addBaseClass(self, obj):
        self.inherit_from.append(obj)

    def fixVirtualFunction(self, index, newName):
        """Set real name for virtfunc and unknown virtfunc
        generic names.

        Only pure virtual entries are renamed. Returns True when the entry
        at index was renamed.
        """
        if index >= len(self.virtual_functions):
            sys.stderr.write('FVF Error {} > {} for {}/{}\n'.format(index, len(self.virtual_functions), newName, self.fullname()))
            return False

        virtfunc = self.virtual_functions[index]

        if not virtfunc.isPure():
            return False

        if virtfunc.name.startswith('virtfunc') or\
           virtfunc.name.startswith('unknown_virtfunc'):
            virtfunc.name = newName
            self._checkConstructor(virtfunc)
            return True

        return False

    def hasMultipleVirtualSections(self):
        """Check if we have a vtable_indexX (X>0) entry in our
        virtual functions table
        """
        return any(vfunc.name.startswith('vtable_index') and
                   vfunc.name != 'vtable_index0'
                   for vfunc in self.virtual_functions)

    def fixupInheritance(self):
        """Report virtual function names into base classes whose matching
        entries are still generic pure virtual names (see
        fixVirtualFunction). Bases are handled first, recursively.

        With several virtual table sections (vtable_indexN markers), the
        Nth base is matched against the Nth section. Otherwise the single
        table is matched against the bases one after the other.
        """
        if not self.inherit_from:
            return

        # First, handle all of our bases
        for base in self.inherit_from:
            base.fixupInheritance()

        vfuncs = self.virtual_functions
        curIdx = 0
        if self.hasMultipleVirtualSections():
            for base in self.inherit_from:
                # First entry of a section is its vtable_index marker, skip it
                # (the base's own index 0 is skipped likewise)
                curIdx += 1
                targetIdx = 1
                updated = False
                while curIdx < len(vfuncs) and \
                      targetIdx < len(base.virtual_functions):
                    vfunc = vfuncs[curIdx]
                    # Reached the next section
                    if vfunc.name.startswith('vtable_index'):
                        break
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()
                # Go to next vtable index if we are not already on it
                while curIdx < len(vfuncs) and \
                      not vfuncs[curIdx].name.startswith('vtable_index'):
                    curIdx += 1
        else:
            for base in self.inherit_from:
                targetIdx = 0
                updated = False
                while curIdx < len(vfuncs) and \
                      targetIdx < len(base.virtual_functions):
                    if base.fixVirtualFunction(targetIdx, vfuncs[curIdx].name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()

    def looksLikeNamespace(self):
        """Empty specific class attributes looks like a namespace
        """
        return not self.constructors and\
            not self.destructors and\
            not self.inherit_from and\
            not self.virtual_functions

    def _getDependencies(self, targetList, dependencies):
        """Append to dependencies the dependencies of targetList objects
        that are outside this class's namespace.

        A dependency is inside the namespace when it equals it or starts
        with "<namespace>::".
        """
        # NOTE(review): a class without namespace reports no member
        # dependencies at all (unchanged behaviour) -- confirm intended.
        if not self.namespace:
            return
        prefix = '{}::'.format(self.namespace)
        for obj in targetList:
            res = obj.getDependencies()
            if res:
                for d in res:
                    if d != self.namespace and not d.startswith(prefix):
                        dependencies.append(d)

    def getDependencies(self):
        """Get dependencies from other namespaces (list without duplicates,
        or None)"""
        dependencies = []

        for base in self.inherit_from:
            if base.namespace:
                dependencies.append(base.fullname())

        self._getDependencies(self.constructors, dependencies)
        self._getDependencies(self.destructors, dependencies)
        self._getDependencies(self.virtual_functions, dependencies)
        self._getDependencies(self.childs, dependencies)

        if dependencies:
            return list(set(dependencies))
        return None

    def printOverloadedVirtualTable(self):
        """Only select overloaded methods from
        virtual table: markers, typeinfo and functions from another
        namespace are skipped, as are repeated entries.
        """
        res = ''
        for vfunc in self.virtual_functions:
            if vfunc.name.startswith('vtable_index') or\
               vfunc.name.startswith('typeinfo()'):
                continue

            # Not overloaded by us
            if vfunc.namespace and vfunc.namespace != self.namespace:
                continue

            vfunc_res = vfunc.__str__()

            # Method already in result,
            # It's the case for virtual descriptor
            if vfunc_res in res:
                continue

            res += vfunc_res

        return res

    def __str__(self):
        global cur_indent
        res = '{}class {}'.format(_getIndent(), self.name)
        if self.inherit_from:
            res += ': '
            bases = []
            for base in self.inherit_from:
                bases.append('public {}'.format(base.fullname()))
            res += ', '.join(bases)
        res += '\n{}{{\n'.format(_getIndent())
        res += '{}public:\n'.format(_getIndent())
        cur_indent += 1

        for constructor in sorted(self.constructors):
            res += constructor.__str__()
        if len(self.constructors): res += '\n'

        for destructor in sorted(self.destructors):
            res += destructor.__str__()
        if len(self.destructors): res += '\n'

        # Do not sort virtual tables !
        virtfuncs = ''
        if print_raw_virtual_table:
            for virtfunc in self.virtual_functions:
                virtfuncs += virtfunc.__str__()
        else:
            virtfuncs = self.printOverloadedVirtualTable()

        if len(virtfuncs): res += '{}\n'.format(virtfuncs)

        methods = []
        other = []

        for child in self.childs:
            if type(child) == Function: methods.append(child)
            else: other.append(child)

        for method in sorted(methods):
            res += method.__str__()
        if methods: res += '\n'

        for child in sorted(other):
            res += child.__str__()
        if other: res += '\n'

        cur_indent -= 1
        res += '{}}};\n\n'.format(_getIndent())

        return res
 |