Initial commit
This commit is contained in:
536
objects.py
Normal file
536
objects.py
Normal file
@@ -0,0 +1,536 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright Grégory Soutadé
|
||||
|
||||
# This file is part of SOAdvancedDissector
|
||||
|
||||
# SOAdvancedDissector is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# SOAdvancedDissector is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with SOAdvancedDissector. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
|
||||
import sys
|
||||
from cppprototypeparser import CPPPrototypeParser
|
||||
|
||||
print_raw_virtual_table = False
|
||||
print_indent = False
|
||||
cur_indent = 0
|
||||
|
||||
parser = CPPPrototypeParser()
|
||||
|
||||
def setPrintRawVirtualTable(value):
|
||||
global print_raw_virtual_table
|
||||
print_raw_virtual_table = value
|
||||
|
||||
def setPrintIndent(value):
|
||||
global print_indent
|
||||
print_indent = value
|
||||
|
||||
def _getIndent():
|
||||
global print_indent
|
||||
indent = ''
|
||||
if print_indent:
|
||||
indent = ' '*cur_indent
|
||||
return indent
|
||||
|
||||
class Object:
|
||||
"""Abstract object representation (simply a name)"""
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def find(self, obj):
|
||||
if self == obj:
|
||||
return self
|
||||
return None
|
||||
|
||||
def getParametersDependencies(self):
|
||||
return None
|
||||
|
||||
def getDependencies(self):
|
||||
"""Get dependencies from other namespaces"""
|
||||
return None
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(other) == str:
|
||||
return self.name == other
|
||||
else:
|
||||
return self.name == other.name
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.name < other.name
|
||||
|
||||
class Attribute(Object):
|
||||
"""Class attribute (member)"""
|
||||
def __init__(self, name, address=0, namespace=''):
|
||||
Object.__init__(self, name)
|
||||
self.address = address
|
||||
self.namespace = namespace
|
||||
|
||||
def fullname(self):
|
||||
"""Return namespace::name"""
|
||||
if self.namespace:
|
||||
return '{}::{}'.format(self.namespace, self.name)
|
||||
else:
|
||||
return self.name
|
||||
|
||||
def __eq__(self, other):
|
||||
if type(other) == str:
|
||||
return self.name == other
|
||||
else:
|
||||
if hasattr(other, 'address'):
|
||||
return self.address == other.address
|
||||
else:
|
||||
return self.name == other.name
|
||||
|
||||
def __str__(self):
|
||||
return '{}void* {};\n'.format(_getIndent(), self.name)
|
||||
|
||||
|
||||
class Function(Attribute):
|
||||
"""Function description"""
|
||||
def __init__(self, name, address=0, virtual=False, pure_virtual=False, namespace=''):
|
||||
# Be sure we have () in name
|
||||
if '(' in name and not name.endswith(')') and not name.endswith('const')\
|
||||
and not name.endswith('const&'):
|
||||
name += ')'
|
||||
Attribute.__init__(self, name, address, namespace)
|
||||
self.virtual = virtual
|
||||
self.pure_virtual = pure_virtual
|
||||
self.constructor = False
|
||||
|
||||
def isPure(self):
|
||||
"""Is function pure virtual"""
|
||||
return self.pure_virtual
|
||||
|
||||
def setConstructor(self, value):
|
||||
"""Set/clear constructor property"""
|
||||
self.constructor = value
|
||||
|
||||
def getDependencies(self):
|
||||
"""Get dependencies from other namespaces"""
|
||||
dependencies = []
|
||||
parser.parse(self.name)
|
||||
if not parser.has_parameters:
|
||||
return None
|
||||
for p in parser.parameters:
|
||||
# If parameter has namespace, add into dependencies
|
||||
if type(p) == list and len(p) > 1 and not self.namespace.startswith(p[0]):
|
||||
dep = '::'.join(p)
|
||||
if not '::' in dep: continue
|
||||
dep = dep.replace('*', '')
|
||||
dep = dep.replace('&', '')
|
||||
if dep.endswith(' const'):
|
||||
dep = dep[:-len(' const')]
|
||||
dependencies.append(dep)
|
||||
if dependencies:
|
||||
return list(set(dependencies))
|
||||
return None
|
||||
|
||||
def _getReturnType(self):
|
||||
"""Get method return type"""
|
||||
type_ = 'void '
|
||||
if self.name.startswith('vtable_index'): type_='void* '
|
||||
elif self.name.startswith('operator new'): type_='void* '
|
||||
elif self.constructor: type_ = ''
|
||||
return type_
|
||||
|
||||
def __str__(self):
|
||||
res = ''
|
||||
type_ = self._getReturnType()
|
||||
if self.pure_virtual:
|
||||
res = 'virtual {}{} = 0;\n'.format(type_, self.name)
|
||||
elif self.virtual:
|
||||
res = 'virtual {}{};\n'.format(type_, self.name)
|
||||
else:
|
||||
res = '{}{};\n'.format(type_, self.name)
|
||||
res = '{}{}'.format(_getIndent(), res)
|
||||
return res
|
||||
|
||||
class Namespace(Object):
|
||||
"""Namespace description"""
|
||||
def __init__(self, name):
|
||||
Object.__init__(self, name)
|
||||
self.childs = []
|
||||
self.dependencies = [] # Dependencies from objects in other namespace
|
||||
|
||||
def addChild(self, child):
|
||||
"""Add child (function, class, attribute) to namespace"""
|
||||
if not child in self.childs:
|
||||
self.childs.append(child)
|
||||
|
||||
def removeChild(self, child):
|
||||
"""Remove child from namespace"""
|
||||
self.childs.remove(child)
|
||||
|
||||
def child(self, name):
|
||||
"""Try to find name in childs without recursion"""
|
||||
for child in self.childs:
|
||||
if child.name == name:
|
||||
return child
|
||||
return None
|
||||
|
||||
def find(self, obj):
|
||||
"""Try to find obj in childs and their own child (with recursion)"""
|
||||
if self == obj:
|
||||
return self
|
||||
|
||||
for child in self.childs:
|
||||
res = child.find(obj)
|
||||
if res:
|
||||
return res
|
||||
return None
|
||||
|
||||
def fillFrom(self, other):
|
||||
"""Copy all childs from other object"""
|
||||
for child in other.childs:
|
||||
self.childs.append(child)
|
||||
|
||||
def getDependencies(self):
|
||||
"""Get dependencies from other namespaces"""
|
||||
dependencies = []
|
||||
|
||||
for child in self.childs:
|
||||
depend = child.getDependencies()
|
||||
if depend:
|
||||
for d in depend:
|
||||
if not d.startswith('{}::'.format(self.name)):
|
||||
dependencies.append(d)
|
||||
|
||||
if dependencies:
|
||||
return list(set(dependencies))
|
||||
return []
|
||||
|
||||
def __str__(self):
|
||||
global cur_indent
|
||||
if self.name != 'global':
|
||||
res = '{}namespace {} {{\n\n'.format(_getIndent(), self.name)
|
||||
cur_indent += 1
|
||||
else:
|
||||
res = ''
|
||||
|
||||
namespaces = []
|
||||
classes = []
|
||||
functions = []
|
||||
other = []
|
||||
for child in self.childs:
|
||||
if type(child) == Namespace: namespaces.append(child)
|
||||
elif type(child) == Class: classes.append(child)
|
||||
elif type(child) == Function: functions.append(child)
|
||||
else: other.append(child)
|
||||
|
||||
# Compute classes inheritance dependency
|
||||
classes_dep = []
|
||||
for class_ in sorted(classes):
|
||||
isDep = False
|
||||
for pos, class2 in enumerate(classes_dep):
|
||||
if class_ in class2.inherit_from:
|
||||
isDep = True
|
||||
classes_dep.insert(pos, class_)
|
||||
break
|
||||
if not isDep:
|
||||
classes_dep.append(class_)
|
||||
|
||||
|
||||
# Add class declaration
|
||||
if len(classes_dep) > 1:
|
||||
for c in classes_dep:
|
||||
res += '{}class {};\n'.format(_getIndent(), c.name)
|
||||
if classes_dep: res += '\n\n'
|
||||
|
||||
for namespace in sorted(namespaces):
|
||||
res += namespace.__str__()
|
||||
if namespaces: res += '\n'
|
||||
|
||||
for c in classes_dep:
|
||||
res += c.__str__()
|
||||
if classes_dep: res += '\n'
|
||||
|
||||
for func in sorted(functions):
|
||||
res += func.__str__()
|
||||
if functions: res += '\n'
|
||||
|
||||
for child in sorted(other):
|
||||
res += child.__str__()
|
||||
if other: res += '\n'
|
||||
|
||||
if self.name != 'global':
|
||||
cur_indent -= 1
|
||||
res += '{}}}\n'.format(_getIndent())
|
||||
res += '\n'
|
||||
|
||||
return res
|
||||
|
||||
class Class(Namespace):
|
||||
"""Class description"""
|
||||
def __init__(self, name, namespace=''):
|
||||
Namespace.__init__(self, name)
|
||||
|
||||
self.constructors = []
|
||||
self.destructors = []
|
||||
self.inherit_from = []
|
||||
self.virtual_functions = []
|
||||
self.namespace = namespace
|
||||
|
||||
def fullname(self):
|
||||
"""Return namespace::name"""
|
||||
if self.namespace:
|
||||
return '{}::{}'.format(self.namespace, self.name)
|
||||
else:
|
||||
return self.name
|
||||
|
||||
def _checkConstructor(self, obj):
|
||||
"""Check if obj is a constructor/destructor and
|
||||
set its property.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
Adequat list (constructor or destructor) or None
|
||||
"""
|
||||
if type(obj) != Function: return None
|
||||
if obj.name.startswith('~'):
|
||||
# sys.stderr.write('Check C {} -> D\n'.format(obj.name))
|
||||
obj.setConstructor(True)
|
||||
return self.destructors
|
||||
if obj.name.startswith('{}('.format(self.name)):
|
||||
# sys.stderr.write('Check C {} -> C\n'.format(obj.name))
|
||||
obj.setConstructor(True)
|
||||
return self.constructors
|
||||
# sys.stderr.write('Check C {} -> N\n'.format(obj.name))
|
||||
return None
|
||||
|
||||
def addVirtualFunction(self, obj):
|
||||
"""Add a new virtual function"""
|
||||
if obj.address == 0 or not obj in self.virtual_functions:
|
||||
self._checkConstructor(obj)
|
||||
self.virtual_functions.append(obj)
|
||||
|
||||
def updateVirtualFunction(self, idx, obj):
|
||||
"""Update virtual function at index idx"""
|
||||
try:
|
||||
self._checkConstructor(obj)
|
||||
self.virtual_functions[idx] = obj
|
||||
except:
|
||||
sys.stderr.write('updateVirtualFunction Error, cur vtable size {}; idx {}, class {}, obj {}\n'.format(len(self.virtual_functions), idx, self.name, obj))
|
||||
sys.stderr.flush()
|
||||
|
||||
def addMember(self, obj):
|
||||
"""Add a new member"""
|
||||
if obj in self.virtual_functions or\
|
||||
obj in self.constructors or\
|
||||
obj in self.destructors or\
|
||||
obj in self.childs:
|
||||
return
|
||||
|
||||
targetList = self._checkConstructor(obj)
|
||||
if targetList is None:
|
||||
self.childs.append(obj)
|
||||
else:
|
||||
targetList.append(obj)
|
||||
|
||||
def addChild(self, child):
|
||||
return self.addMember(child)
|
||||
|
||||
def addBaseClass(self, obj):
|
||||
self.inherit_from.append(obj)
|
||||
|
||||
def fixVirtualFunction(self, index, newName):
|
||||
"""Set real name for virtfunc and unknown virtfunc
|
||||
generic names
|
||||
"""
|
||||
if index >= len(self.virtual_functions):
|
||||
sys.stderr.write('FVF Error {} > {} for {}/{}\n'.format(index, len(self.virtual_functions), newName, self.fullname()))
|
||||
return False
|
||||
|
||||
virtfunc = self.virtual_functions[index]
|
||||
|
||||
if not virtfunc.isPure():
|
||||
return False
|
||||
|
||||
if virtfunc.name.startswith('virtfunc') or\
|
||||
virtfunc.name.startswith('unknown_virtfunc'):
|
||||
virtfunc.name = newName
|
||||
self._checkConstructor(virtfunc)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def hasMultipleVirtualSections(self):
|
||||
"""Check if we have a vtable_indexX (X>0) entry in our
|
||||
virtual functions table
|
||||
"""
|
||||
if not self.virtual_functions:
|
||||
return False
|
||||
|
||||
for vfunc in self.virtual_functions:
|
||||
if vfunc.name.startswith('vtable_index') and\
|
||||
vfunc.name != 'vtable_index0':
|
||||
return True
|
||||
return False
|
||||
|
||||
def fixupInheritance(self):
|
||||
"""Report virtual function name in base class
|
||||
if there are pure virtual.
|
||||
"""
|
||||
if not self.inherit_from:
|
||||
return
|
||||
|
||||
# First, handle all of our bases
|
||||
for base in self.inherit_from:
|
||||
base.fixupInheritance()
|
||||
|
||||
updated = False
|
||||
curIdx = 0
|
||||
if self.hasMultipleVirtualSections():
|
||||
for base in self.inherit_from:
|
||||
# First is vtable_index, skip it
|
||||
curIdx += 1
|
||||
targetIdx = 1
|
||||
updated = False
|
||||
while curIdx < len(base.virtual_functions):
|
||||
vfunc = self.virtual_functions[curIdx]
|
||||
if vfunc.name.startswith('virtual_index'):
|
||||
break
|
||||
if base.fixVirtualFunction(targetIdx, vfunc.name):
|
||||
updated = True
|
||||
curIdx += 1
|
||||
targetIdx += 1
|
||||
if updated:
|
||||
base.fixupInheritance()
|
||||
# Go to next vtable index if we are not
|
||||
while curIdx < len(self.virtual_functions):
|
||||
vfunc = self.virtual_functions[curIdx]
|
||||
if vfunc.name.startswith('virtual_index'):
|
||||
break
|
||||
curIdx += 1
|
||||
else:
|
||||
for base in self.inherit_from:
|
||||
targetIdx = 0
|
||||
while targetIdx < len(base.virtual_functions):
|
||||
vfunc = self.virtual_functions[curIdx]
|
||||
if base.fixVirtualFunction(targetIdx, vfunc.name):
|
||||
updated = True
|
||||
curIdx += 1
|
||||
targetIdx += 1
|
||||
if updated:
|
||||
base.fixupInheritance()
|
||||
|
||||
def looksLikeNamespace(self):
|
||||
"""Empty specific class attributes looks like a namespace
|
||||
"""
|
||||
return not self.constructors and\
|
||||
not self.destructors and\
|
||||
not self.inherit_from and\
|
||||
not self.virtual_functions
|
||||
|
||||
def _getDependencies(self, targetList, dependencies):
|
||||
for obj in targetList:
|
||||
res = obj.getDependencies()
|
||||
if res:
|
||||
for d in res:
|
||||
if self.namespace and not d.startswith(self.namespace):
|
||||
dependencies.append(d)
|
||||
|
||||
def getDependencies(self):
|
||||
"""Get dependencies from other namespaces"""
|
||||
dependencies = []
|
||||
|
||||
for base in self.inherit_from:
|
||||
if base.namespace:
|
||||
dependencies.append(base.fullname())
|
||||
|
||||
self._getDependencies(self.constructors, dependencies)
|
||||
self._getDependencies(self.destructors, dependencies)
|
||||
self._getDependencies(self.virtual_functions, dependencies)
|
||||
self._getDependencies(self.childs, dependencies)
|
||||
|
||||
if dependencies:
|
||||
return list(set(dependencies))
|
||||
return None
|
||||
|
||||
def printOverloadedVirtualTable(self):
|
||||
"""Only select overloaded methods from
|
||||
virtual table
|
||||
"""
|
||||
res = ''
|
||||
vfunc_res = ''
|
||||
for vfunc in self.virtual_functions:
|
||||
if vfunc.name.startswith('vtable_index') or\
|
||||
vfunc.name.startswith('typeinfo()'):
|
||||
continue
|
||||
|
||||
# Not overloaded by us
|
||||
if vfunc.namespace and vfunc.namespace != self.namespace:
|
||||
continue
|
||||
|
||||
vfunc_res = vfunc.__str__()
|
||||
|
||||
# Method already in result,
|
||||
# It's the case for virtual descriptor
|
||||
if vfunc_res in res:
|
||||
continue
|
||||
|
||||
res += vfunc_res
|
||||
|
||||
return res
|
||||
|
||||
def __str__(self):
|
||||
global print_raw_virtual_table
|
||||
global cur_indent
|
||||
res = '{}class {}'.format(_getIndent(), self.name)
|
||||
if self.inherit_from:
|
||||
res += ': '
|
||||
bases = []
|
||||
for base in self.inherit_from:
|
||||
bases.append('public {}'.format(base.fullname()))
|
||||
res += ', '.join(bases)
|
||||
res += '\n{}{{\n'.format(_getIndent())
|
||||
res += '{}public:\n'.format(_getIndent())
|
||||
cur_indent += 1
|
||||
|
||||
for constructor in sorted(self.constructors):
|
||||
res += constructor.__str__()
|
||||
if len(self.constructors): res += '\n'
|
||||
|
||||
for destructor in sorted(self.destructors):
|
||||
res += destructor.__str__()
|
||||
if len(self.destructors): res += '\n'
|
||||
|
||||
# Do not sort virtual tables !
|
||||
virtfuncs = ''
|
||||
if print_raw_virtual_table:
|
||||
for virtfunc in self.virtual_functions:
|
||||
virtfuncs += virtfunc.__str__()
|
||||
else:
|
||||
virtfuncs = self.printOverloadedVirtualTable()
|
||||
|
||||
if len(virtfuncs): res += '{}\n'.format(virtfuncs)
|
||||
|
||||
methods = []
|
||||
other = []
|
||||
|
||||
for child in self.childs:
|
||||
if type(child) == Function: methods.append(child)
|
||||
else: other.append(child)
|
||||
|
||||
for method in sorted(methods):
|
||||
res += method.__str__()
|
||||
if methods: res += '\n'
|
||||
|
||||
for child in sorted(other):
|
||||
res += child.__str__()
|
||||
if other: res += '\n'
|
||||
|
||||
cur_indent -= 1
|
||||
res += '{}}};\n\n'.format(_getIndent())
|
||||
|
||||
return res
|
Reference in New Issue
Block a user