537 lines
17 KiB
Python
537 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright Grégory Soutadé
|
|
|
|
# This file is part of SOAdvancedDissector
|
|
|
|
# SOAdvancedDissector is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# SOAdvancedDissector is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with SOAdvancedDissector. If not, see <http://www.gnu.org/licenses/>.
|
|
#
|
|
|
|
import sys
|
|
from cppprototypeparser import CPPPrototypeParser
|
|
|
|
# Output options, changed through setPrintRawVirtualTable() / setPrintIndent().
print_raw_virtual_table = print_indent = False
# Current nesting depth read by _getIndent(); Namespace/Class __str__
# increment it on entry and decrement it before closing their scope.
cur_indent = 0

# Single prototype parser shared by every Function.getDependencies() call.
parser = CPPPrototypeParser()
|
|
|
|
def setPrintRawVirtualTable(value):
    """Choose how Class.__str__ prints virtual tables.

    A true *value* prints every virtual table entry as-is; otherwise only
    the entries selected by Class.printOverloadedVirtualTable are printed.
    """
    global print_raw_virtual_table
    print_raw_virtual_table = value
|
|
|
|
def setPrintIndent(value):
    """Enable or disable indentation of the generated output.

    When *value* is true, _getIndent() returns ``cur_indent`` spaces;
    otherwise it always returns an empty string.
    """
    global print_indent
    print_indent = value
|
|
|
|
def _getIndent():
|
|
global print_indent
|
|
indent = ''
|
|
if print_indent:
|
|
indent = ' '*cur_indent
|
|
return indent
|
|
|
|
class Object:
    """Abstract object representation, identified only by its name."""

    def __init__(self, name):
        self.name = name

    def find(self, obj):
        """Return self when it compares equal to *obj*, None otherwise."""
        return self if self == obj else None

    def getParametersDependencies(self):
        """A bare object has no parameter dependencies: always None."""
        return None

    def getDependencies(self):
        """Get dependencies from other namespaces (none for a bare object)."""
        return None

    def __eq__(self, other):
        # A plain string is compared with our name directly; any other
        # operand is expected to expose a ``name`` attribute.
        target = other if type(other) is str else other.name
        return self.name == target

    def __lt__(self, other):
        # Name ordering, used whenever children are sorted for printing
        return self.name < other.name
|
|
|
|
class Attribute(Object):
    """Class attribute (member) with an address and an optional namespace."""

    def __init__(self, name, address=0, namespace=''):
        Object.__init__(self, name)
        self.address = address      # caller-supplied address, 0 by default
        self.namespace = namespace  # '' means "no namespace"

    def fullname(self):
        """Return ``namespace::name``, or just ``name`` without a namespace."""
        return '{}::{}'.format(self.namespace, self.name) if self.namespace else self.name

    def __eq__(self, other):
        # A plain string is matched against our name only
        if type(other) is str:
            return self.name == other
        # Anything carrying an address is identified by that address.
        # NOTE(review): two entries that both kept the default address 0
        # therefore compare equal regardless of their names.
        if hasattr(other, 'address'):
            return self.address == other.address
        return self.name == other.name

    def __str__(self):
        prefix = _getIndent()
        return '{}void* {};\n'.format(prefix, self.name)
|
|
|
|
|
|
class Function(Attribute):
    """Function (or method) description."""

    def __init__(self, name, address=0, virtual=False, pure_virtual=False, namespace=''):
        # Be sure a parameter list opened with '(' is closed; names ending
        # with ')' , 'const' or 'const&' are kept as they are.
        if '(' in name and not name.endswith((')', 'const', 'const&')):
            name += ')'
        Attribute.__init__(self, name, address, namespace)
        self.virtual = virtual
        self.pure_virtual = pure_virtual
        self.constructor = False

    def isPure(self):
        """Is function pure virtual"""
        return self.pure_virtual

    def setConstructor(self, value):
        """Set/clear constructor property"""
        self.constructor = value

    def getDependencies(self):
        """Get dependencies from other namespaces.

        The prototype is parsed with the shared module-level parser. Every
        parameter given as a list of several components whose first
        component is not a prefix of our namespace is joined with '::',
        stripped of '*' / '&' and of a trailing ' const', and collected.

        Returns a list of unique names, or None when there are none or
        the prototype has no parameters.
        """
        parser.parse(self.name)
        if not parser.has_parameters:
            return None

        found = []
        for param in parser.parameters:
            # Only multi-component (qualified) parameters are interesting
            if type(param) != list or len(param) <= 1:
                continue
            # Skip types belonging to our own namespace
            if self.namespace.startswith(param[0]):
                continue
            qualified = '::'.join(param)
            if '::' not in qualified:
                continue
            qualified = qualified.replace('*', '').replace('&', '')
            if qualified.endswith(' const'):
                qualified = qualified[:-len(' const')]
            found.append(qualified)

        return list(set(found)) if found else None

    def _getReturnType(self):
        """Return the return-type prefix printed in front of the prototype."""
        if self.name.startswith(('vtable_index', 'operator new')):
            return 'void* '
        if self.constructor:
            # Constructors and destructors have no return type
            return ''
        return 'void '

    def __str__(self):
        prototype = '{}{}'.format(self._getReturnType(), self.name)
        if self.pure_virtual:
            line = 'virtual {} = 0;\n'.format(prototype)
        elif self.virtual:
            line = 'virtual {};\n'.format(prototype)
        else:
            line = '{};\n'.format(prototype)
        return _getIndent() + line
|
|
|
|
class Namespace(Object):
    """Namespace description: a named container of child objects
    (namespaces, classes, functions, attributes)."""

    def __init__(self, name):
        Object.__init__(self, name)
        self.childs = []
        # Dependencies from objects in other namespace
        self.dependencies = []

    def addChild(self, child):
        """Append *child* (function, class, attribute) unless an equal
        object is already one of our children."""
        if child in self.childs:
            return
        self.childs.append(child)

    def removeChild(self, child):
        """Remove *child* from the direct children (list.remove semantics)."""
        self.childs.remove(child)

    def child(self, name):
        """Return the first direct child named *name*, or None.

        Only direct children are inspected (no recursion).
        """
        return next((candidate for candidate in self.childs
                     if candidate.name == name), None)

    def find(self, obj):
        """Depth-first search for *obj*: self first, then each child's find().

        Returns the first truthy match, or None.
        """
        if self == obj:
            return self
        for candidate in self.childs:
            hit = candidate.find(obj)
            if hit:
                return hit
        return None

    def fillFrom(self, other):
        """Append every child of *other* to ours, without duplicate checks."""
        append = self.childs.append
        for item in other.childs:
            append(item)

    def getDependencies(self):
        """Collect the children's dependencies on other namespaces.

        Names already qualified with our own name ("<name>::...") are
        dropped. Returns a list of unique names, possibly empty.
        """
        own_prefix = '{}::'.format(self.name)
        found = [dep
                 for member in self.childs
                 for dep in (member.getDependencies() or [])
                 if not dep.startswith(own_prefix)]
        return list(set(found)) if found else []

    @staticmethod
    def _orderClassesByInheritance(classes):
        """Sort *classes* by name, then insert each one in front of the first
        already-placed class that lists it among its bases."""
        ordered = []
        for candidate in sorted(classes):
            position = next((idx for idx, placed in enumerate(ordered)
                             if candidate in placed.inherit_from), None)
            if position is None:
                ordered.append(candidate)
            else:
                ordered.insert(position, candidate)
        return ordered

    def __str__(self):
        global cur_indent
        is_global = self.name == 'global'
        if is_global:
            text = ''
        else:
            text = '{}namespace {} {{\n\n'.format(_getIndent(), self.name)
            cur_indent += 1

        # Split children by their exact type
        others = []
        buckets = {Namespace: [], Class: [], Function: []}
        for item in self.childs:
            buckets.get(type(item), others).append(item)

        # Base classes must be printed before the classes deriving from them
        ordered_classes = self._orderClassesByInheritance(buckets[Class])

        # Forward declarations, only when there are several classes
        if len(ordered_classes) > 1:
            text += ''.join('{}class {};\n'.format(_getIndent(), cls.name)
                            for cls in ordered_classes)
        if ordered_classes:
            text += '\n\n'

        # Nested namespaces, classes, functions, then everything else;
        # each non-empty group is followed by an empty line
        groups = (sorted(buckets[Namespace]), ordered_classes,
                  sorted(buckets[Function]), sorted(others))
        for group in groups:
            for item in group:
                text += str(item)
            if group:
                text += '\n'

        if not is_global:
            cur_indent -= 1
            text += '{}}}\n'.format(_getIndent())
        return text + '\n'
|
|
|
|
class Class(Namespace):
    """Class description.

    On top of the children held by Namespace, a class tracks its
    constructors, destructors, base classes and virtual table entries in
    dedicated lists.
    """

    def __init__(self, name, namespace=''):
        Namespace.__init__(self, name)

        self.constructors = []
        self.destructors = []
        self.inherit_from = []       # base Class objects, in the order added
        self.virtual_functions = []  # virtual table entries; order is significant
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        else:
            return self.name

    def _checkConstructor(self, obj):
        """Check if obj is a constructor/destructor and
        set its property.

        A Function whose name starts with '~' is treated as a destructor;
        one whose name starts with '<class name>(' as a constructor. In
        both cases obj.setConstructor(True) is called.

        Returns
        -------
        list
            Adequate list (constructors or destructors) or None
        """
        if type(obj) != Function:
            return None
        if obj.name.startswith('~'):
            obj.setConstructor(True)
            return self.destructors
        if obj.name.startswith('{}('.format(self.name)):
            obj.setConstructor(True)
            return self.constructors
        return None

    def addVirtualFunction(self, obj):
        """Append obj to the virtual table.

        Entries with address 0 are always appended; other entries are
        skipped when an equal entry is already in the table.
        """
        if obj.address == 0 or not obj in self.virtual_functions:
            self._checkConstructor(obj)
            self.virtual_functions.append(obj)

    def updateVirtualFunction(self, idx, obj):
        """Replace the virtual table entry at index idx with obj.

        An invalid index is reported on stderr and otherwise ignored.
        """
        self._checkConstructor(obj)
        try:
            self.virtual_functions[idx] = obj
        except (IndexError, TypeError):
            # Only the indexing is expected to fail here; a bare except also
            # swallowed KeyboardInterrupt/SystemExit and genuine bugs.
            sys.stderr.write('updateVirtualFunction Error, cur vtable size {}; idx {}, class {}, obj {}\n'.format(len(self.virtual_functions), idx, self.name, obj))
            sys.stderr.flush()

    def addMember(self, obj):
        """Add a new member.

        Ignored when an equal object is already a virtual function,
        constructor, destructor or child. Constructors and destructors go
        to their dedicated list, anything else to childs.
        """
        if obj in self.virtual_functions or\
           obj in self.constructors or\
           obj in self.destructors or\
           obj in self.childs:
            return

        targetList = self._checkConstructor(obj)
        if targetList is None:
            self.childs.append(obj)
        else:
            targetList.append(obj)

    def addChild(self, child):
        """Namespace.addChild override: children go through addMember."""
        return self.addMember(child)

    def addBaseClass(self, obj):
        """Append obj to the list of base classes."""
        self.inherit_from.append(obj)

    def fixVirtualFunction(self, index, newName):
        """Set real name for virtfunc and unknown virtfunc
        generic names.

        The entry at index is renamed to newName only when it is pure
        virtual and its name starts with 'virtfunc' or 'unknown_virtfunc'.

        Returns True when the entry was renamed, False otherwise (an
        out-of-range index is also reported on stderr).
        """
        if index >= len(self.virtual_functions):
            sys.stderr.write('FVF Error {} > {} for {}/{}\n'.format(index, len(self.virtual_functions), newName, self.fullname()))
            return False

        virtfunc = self.virtual_functions[index]

        if not virtfunc.isPure():
            return False

        if virtfunc.name.startswith('virtfunc') or\
           virtfunc.name.startswith('unknown_virtfunc'):
            virtfunc.name = newName
            self._checkConstructor(virtfunc)
            return True

        return False

    def hasMultipleVirtualSections(self):
        """Check if we have a vtable_indexX (X>0) entry in our
        virtual functions table
        """
        return any(vfunc.name.startswith('vtable_index') and
                   vfunc.name != 'vtable_index0'
                   for vfunc in self.virtual_functions)

    def fixupInheritance(self):
        """Report virtual function names into base classes whose
        matching slots are still generic pure virtual placeholders.

        Bases are fixed up first (recursively). Our virtual table is then
        walked in parallel with each base's table, and
        base.fixVirtualFunction() is offered our entry's name for each
        base slot. A base that was updated is fixed up again so the names
        propagate further up the hierarchy.

        With several virtual sections (a vtable_indexN entry with N > 0),
        each base is matched against one section. A section starts right
        after its vtable_index marker, and the base's slot 0 is skipped.
        """
        if not self.inherit_from:
            return

        # First, handle all of our bases
        for base in self.inherit_from:
            base.fixupInheritance()

        own_size = len(self.virtual_functions)
        curIdx = 0
        if self.hasMultipleVirtualSections():
            for base in self.inherit_from:
                # First is vtable_index, skip it
                curIdx += 1
                targetIdx = 1
                updated = False
                # Bound by BOTH tables: the base's slot count and our own size
                while targetIdx < len(base.virtual_functions) and curIdx < own_size:
                    vfunc = self.virtual_functions[curIdx]
                    # Markers are named 'vtable_index*' (was a 'virtual_index' typo
                    # that never matched, so only the first base was processed)
                    if vfunc.name.startswith('vtable_index'):
                        break
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()
                # Skip our own extra entries up to the next vtable_index marker
                while curIdx < own_size:
                    if self.virtual_functions[curIdx].name.startswith('vtable_index'):
                        break
                    curIdx += 1
        else:
            for base in self.inherit_from:
                targetIdx = 0
                updated = False
                # Guard our own table too: it may be shorter than the base's
                # (e.g. empty), which used to raise IndexError
                while targetIdx < len(base.virtual_functions) and curIdx < own_size:
                    vfunc = self.virtual_functions[curIdx]
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()

    def looksLikeNamespace(self):
        """Empty specific class attributes looks like a namespace
        """
        return not self.constructors and\
            not self.destructors and\
            not self.inherit_from and\
            not self.virtual_functions

    def _getDependencies(self, targetList, dependencies):
        """Append to dependencies the dependencies of every object in
        targetList that do not start with our namespace."""
        # NOTE(review): with an empty self.namespace nothing is collected,
        # and the prefix test has no '::' boundary ('Foo' also matches
        # 'FooBar::X') -- confirm both are intended.
        for obj in targetList:
            res = obj.getDependencies()
            if res:
                for d in res:
                    if self.namespace and not d.startswith(self.namespace):
                        dependencies.append(d)

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Includes the full names of namespaced base classes and the foreign
        dependencies of constructors, destructors, virtual functions and
        other members. Returns a list of unique names, or None.
        """
        dependencies = []

        for base in self.inherit_from:
            if base.namespace:
                dependencies.append(base.fullname())

        self._getDependencies(self.constructors, dependencies)
        self._getDependencies(self.destructors, dependencies)
        self._getDependencies(self.virtual_functions, dependencies)
        self._getDependencies(self.childs, dependencies)

        if dependencies:
            return list(set(dependencies))
        return None

    def printOverloadedVirtualTable(self):
        """Only select overloaded methods from
        virtual table

        vtable_index / typeinfo() slots and entries declared in another
        namespace are skipped; identical printed lines are emitted once.
        """
        lines = []
        seen = set()
        for vfunc in self.virtual_functions:
            if vfunc.name.startswith('vtable_index') or\
               vfunc.name.startswith('typeinfo()'):
                continue

            # Not overloaded by us
            if vfunc.namespace and vfunc.namespace != self.namespace:
                continue

            vfunc_res = vfunc.__str__()

            # Method already printed; it's the case for virtual destructors.
            # Exact-line match (the old substring test on the growing result
            # was quadratic and could match inside another line).
            if vfunc_res in seen:
                continue

            seen.add(vfunc_res)
            lines.append(vfunc_res)

        return ''.join(lines)

    def __str__(self):
        global cur_indent
        res = '{}class {}'.format(_getIndent(), self.name)
        if self.inherit_from:
            res += ': '
            res += ', '.join('public {}'.format(base.fullname())
                             for base in self.inherit_from)
        res += '\n{}{{\n'.format(_getIndent())
        res += '{}public:\n'.format(_getIndent())
        cur_indent += 1

        for constructor in sorted(self.constructors):
            res += constructor.__str__()
        if self.constructors:
            res += '\n'

        for destructor in sorted(self.destructors):
            res += destructor.__str__()
        if self.destructors:
            res += '\n'

        # Do not sort virtual tables: entry order is meaningful
        if print_raw_virtual_table:
            virtfuncs = ''.join(virtfunc.__str__() for virtfunc in self.virtual_functions)
        else:
            virtfuncs = self.printOverloadedVirtualTable()

        if virtfuncs:
            res += '{}\n'.format(virtfuncs)

        methods = []
        other = []
        for child in self.childs:
            if type(child) == Function:
                methods.append(child)
            else:
                other.append(child)

        for method in sorted(methods):
            res += method.__str__()
        if methods:
            res += '\n'

        for child in sorted(other):
            res += child.__str__()
        if other:
            res += '\n'

        cur_indent -= 1
        res += '{}}};\n\n'.format(_getIndent())

        return res
|