Include Krakatau and enjarify resources directly

This commit is contained in:
Nico Mexis 2021-03-31 11:48:14 +02:00
parent 8cbe4301f7
commit a4b452de6a
No known key found for this signature in database
GPG Key ID: 27D6E17CE092AB78
308 changed files with 88612 additions and 19 deletions

Binary file not shown.

Binary file not shown.

27
pom.xml
View File

@ -5,6 +5,13 @@
<artifactId>bytecodeviewer</artifactId>
<version>2.9.23</version>
<properties>
<java.version>8</java.version>
<maven.compiler.target>${java.version}</maven.compiler.target>
<maven.compiler.source>${java.version}</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>com.android</groupId>
@ -109,13 +116,6 @@
<artifactId>dx</artifactId>
<version>1.16</version>
</dependency>
<dependency>
<groupId>enjarify</groupId>
<artifactId>enjarify</artifactId>
<version>3</version>
<scope>system</scope>
<systemPath>${project.basedir}/libs/enjarify-3.jar</systemPath>
</dependency>
<dependency>
<groupId>org.jboss.windup.decompiler</groupId>
<artifactId>decompiler-fernflower</artifactId>
@ -155,13 +155,6 @@
<artifactId>jgraphx</artifactId>
<version>3.4.1.3</version>
</dependency>
<dependency>
<groupId>krakatau</groupId>
<artifactId>krakatau</artifactId>
<version>11</version>
<scope>system</scope>
<systemPath>${project.basedir}/libs/Krakatau-11.jar</systemPath>
</dependency>
<dependency>
<groupId>org.objenesis</groupId>
<artifactId>objenesis</artifactId>
@ -220,8 +213,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>8</source>
<target>8</target>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
<plugin>
@ -229,7 +222,7 @@
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.2.0</version>
<configuration>
<source>8</source>
<source>${java.version}</source>
</configuration>
</plugin>
<plugin>

View File

@ -0,0 +1,8 @@
root = true
[*.py]
charset = utf-8
indent_style = space
indent_size = 4
insert_final_newline = true
end_of_line = lf

View File

@ -0,0 +1,2 @@
* text=auto
*.test binary

View File

@ -0,0 +1,5 @@
*.pyc
*.pyo
test*.py
Krakatau/plugins/*
tests/.cache

View File

@ -0,0 +1,295 @@
Krakatau Assembly Syntax
For a list of previous changes to the assembly syntax, see changelog.txt
Note: This documents the officially supported syntax of the assembler. The assembler accepts some files that don't fully conform to this syntax, but this behavior may change without warning in the future.
Lexical structure
Comments: Comments begin with a semicolon and go to the end of the line. Since no valid tokens start with a semicolon, there is no ambiguity. Comments are ignored during parsing.
Whitespace: At least one consecutive space or tab character
Lines that are empty except for whitespace or a comment are ignored. Many grammar productions require certain parts to be separated by a newline (LF/CRLF/CR). This is represented below by the terminal EOL. Due to the rules above, EOL can represent an optional comment, followed by a newline, followed by any number of empty/comment lines. There are no line continuations.
Integer, Long, Float, and Double literals use the same syntax as Java with a few differences:
* No underscores are allowed
* Doubles cannot be suffixed with d or D.
* Decimal floating point literals with values that cant be represented exactly in the target type arent guaranteed to round the same way as in Java. If the exact value is significant, you should use a hexidecimal floating point literal.
* If a decimal point is present, there must be at least one digit before and after it (0.5 is ok, but .5 is not. 5.0 is ok but 5. is not).
* A leading plus or minus sign is allowed.
* Only decimal and hexadecimal literals are allowed (no binary or octal)
* For doubles, special values can be represented by +Infinity, -Infinity, +NaN, and -NaN (case insensitive). For floats, these should be suffixed by f or F.
* NaNs with a specific binary representation can be represented by suffixing with the hexadecimal value in angle brackets. For example, -NaN<0x7ff0123456789abc> or +NaN<0xFFABCDEF>f
Note: NaN requires a leading sign, even though it is ignored. This is to avoid ambiguity with WORDs. The binary representation of a NaN with no explicit representation may be any valid encoding of NaN. If you care about the binary representation in the classfile, you should specify it explicitly as described above.
String literals use the same syntax as Java string literals with the following exceptions
* Non printable and non-ascii characters, including tabs, are not allowed. These can be represented by escape sequences as usual. For example \t means tab.
* Either single or double quotes can be used. If single quotes are used, double quotes can appear unescaped inside the string and vice versa.
* There are three additional types of escape sequences allowed: \xDD, \uDDDD, and \UDDDDDDDD where D is a hexadecimal digit. The later two are only allowed in unicode strings (see below). In the case of \U, the digits must correspond to a number less than 0x00110000. \x represents a byte or code point up to 255. \u represents a code point up to 65535. \U represents a code point up to 1114111 (0x10FFFF), which will be split into a surrogate pair when encoded if it is above 0xFFFF.
* There are two types of string literals - bytes and unicode. Unicode strings are the default and represent a sequence of code points which will be MUTF8 encoded when written to the classfile. A byte string, represented by prefixing with b or B, represents a raw sequence of bytes which will be written unchanged. For example, "\0" is encoded to a two byte sequence while b"\0" puts an actual null byte in the classfile (which is invalid, but potentially useful for testing).
Reference: The classfile format has a large number of places where an index is made into the constant pool or bootstrap methods table. The assembly format allows you to specify the definition inline, and the assembler will automatically add an entry as appropriate and fill in the index. However, this isnt acceptable in cases where the exact binary layout is important or where a definition is large and you want to refer to it many times without copying the definition each time.
For the first case, there are numeric references, designated by a decimal integer in square brackets with no leading zeroes. For example, [43] refers to the index 43 in the constant pool. For the second case, there are symbolic references, which is a sequence of lowercase ascii, digits, and underscores inside square brackets, not beginning with a digit. For example, [foo_bar4].
Bootstrap method references are the same except preceded by "bs:". For example, [bs:43] or [bs:foo_bar4]. These are represented by the terminal BSREF. Bootstrap method references are only used in very specific circumstances so you probably wont need them. All other references are constant pool references and have no prefix, designated by the terminal CPREF.
Note: Constant pools and bootstrap method tables are class-specific. So definitions inside one class do not affect any other classes assembled from the same source file.
Labels refer to a position within a methods bytecode. The assembler will automatically fill in each label with the calculated numerical offset. Labels consist of a capital L followed by ascii letters, digits, and underscores. A label definition (LBLDEF) is a label followed by a colon (with no space). For example, "LSTART:". Label uses are included in the WORD token type defined below since they dont have a colon.
Note: Labels refer to positions in the bytecode of the enclosing Code attribute where they appear. They may not appear outside of a Code attribute.
Word: A string beginning with a word_start_character, followed by zero or more word_start_character or word_rest_characters. Furthermore, if the first character is a [, it must be followed by another [ or a capital letter (A-Z).
word_start_character: a-z, A-Z, _, $, (, <, [
word_rest_character: 0-9, ), >, /, ;, *, +, -
Words are used to specify names, identifiers, descriptors, and so on. If you need to specify a name that cant be represented as a word (such as using forbidden characters), a string literal can be used instead. Words are represented in the grammar by the terminal WORD.
For example, 42 is not a valid word because it begins with a digit. A class named 42 can be defined as follows:
.class "42"
In addition, when used in a context following flags, words cannot be any of the possible flag names. These are currently public, private, protected, static, final, super, synchronized, open, transitive, volatile, bridge, static_phase, transient, varargs, native, interface, abstract, strict, synthetic, annotation, enum, module, and mandated. In addition, strictfp is disallowed to avoid confusion. So if you wanted to have a string field named bridge, youd have to do
.field "bridge" Ljava/lang/String;
Format of grammar rules
Nonterminals are specified in lowercase. Terminals with a specific value required are specified in quotes. e.g. "Foo" means that the exact text Foo (case sensitive) has to appear at that point. Terminals that require a value of a given token type are represented in all caps, e.g. EOL, INT_LITERAL, FLOAT_LITERAL, LONG_LITERAL, DOUBLE_LITERAL, STRING_LITERAL, CPREF, BSREF, WORD, LBLDEF.
*, +, ?, |, and () have their usual meanings in regular expressions.
Common constant rules
s8: INT_LITERAL
u8: INT_LITERAL
s16: INT_LITERAL
u16: INT_LITERAL
s32: INT_LITERAL
u32: INT_LITERAL
ident: WORD | STRING_LITERAL
utfref: CPREF | ident
clsref: CPREF | ident
natref: CPREF | ident utfref
fmimref: CPREF | fmim_tagged_const
bsref: BSREF | bsnotref
invdynref: CPREF | invdyn_tagged_const
handlecode: "getField" | "getStatic" | "putField" | "putStatic" | "invokeVirtual" | "invokeStatic" | "invokeSpecial" | "newInvokeSpecial" | "invokeInterface"
mhandlenotref: handlecode (CPREF | fmim_tagged_const)
mhandleref: CPREF | mhandlenotref
cmhmt_tagged_const: "Class" utfref | "MethodHandle" mhandlenotref | "MethodType" utfref
ilfds_tagged_const: "Integer" INT_LITERAL | "Float" FLOAT_LITERAL | "Long" LONG_LITERAL | "Double" DOUBLE_LITERAL | "String" STRING_LITERAL
simple_tagged_const: "Utf8" ident | "NameAndType" utfref utfref
fmim_tagged_const: ("Field" | "Method" | "InterfaceMethod") clsref natref
invdyn_tagged_const: "InvokeDynamic" bsref natref
ref_or_tagged_const_ldc: CPREF | cmhmt_tagged_const | ilfds_tagged_const
ref_or_tagged_const_all: ref_or_tagged_ldconst | simple_tagged_const | fmim_tagged_const | invdyn_tagged_const
bsnotref: mhandlenotref ref_or_tagged_const_ldc* ":"
ref_or_tagged_bootstrap: BSREF | "Bootstrap" bsnotref
Note: The most deeply nested possible valid constant is 6 levels (InvokeDynamic -> Bootstrap -> MethodHandle -> Method -> NameAndType -> Utf8). It is possible to create a more deeply nested constant definitions in this grammar by using references with invalid types, but the assembler may reject them.
ldc_rhs: CPREF | INT_LITERAL | FLOAT_LITERAL | LONG_LITERAL | DOUBLE_LITERAL | STRING_LITERAL | cmhmt_tagged_const
flag: "public" | "private" | "protected" | "static" | "final" | "super" | "synchronized" | "volatile" | "bridge" | "transient" | "varargs" | "native" | "interface" | "abstract" | "strict" | "synthetic" | "annotation" | "enum" | "mandated"
Basic assembly structure
assembly_file: EOL? class_definition*
class_definition: version? class_start class_item* class_end
version: ".version" u16 u16 EOL
class_start: class_directive super_directive interface_directive*
class_directive: ".class" flag* clsref EOL
super_directive: ".super" clsref EOL
interface_directive: ".implements" clsref EOL
class_end: ".end" "class" EOL
class_item: const_def | bootstrap_def | field_def | method_def | attribute
const_def: ".const" CPREF "=" ref_or_tagged_const_all EOL
bootstrap_def: ".bootstrap" BSREF "=" ref_or_tagged_bootstrap EOL
Note: If the right hand side is a reference, the left hand side must be a symbolic reference. For example, the following two are valid.
.const [foo] = [bar]
.const [foo] = [42]
While these are not valid.
.const [42] = [foo]
.const [42] = [32]
field_def: ".field" flag* utfref utfref initial_value? field_attributes? EOL
initial_value: "=" ldc_rhs
field_attributes: ".fieldattributes" EOL attribute* ".end" ".fieldattributes"
method_def: method_start (method_body | legacy_method_body) method_end
method_start: ".method" flag* utfref ":" utfref EOL
method_end: ".end" "method" EOL
method_body: attribute*
legacy_method_body: limit_directive+ code_body
limit_directive: ".limit" ("stack" | "locals") u16 EOL
Attributes
attribute: (named_attribute | generic_attribute) EOL
generic_attribute: ".attribute" utfref length_override? attribute_data
length_override: "length" u32
attribute_data: named_attribute | STRING_LITERAL
named_attribute: annotation_default | bootstrap_methods | code | constant_value | deprecated | enclosing_method | exceptions | inner_classes | line_number_table | local_variable_table | local_variable_type_table | method_parameters | runtime_annotations | runtime_visible_parameter_annotations | runtime_visible_type_annotations | signature | source_debug_extension | source_file | stack_map_table | synthetic
annotation_default: ".annotationdefault" element_value
bootstrap_methods: ".bootstrapmethods"
Note: The content of a BootstrapMethods attribute is automatically filled in based on the implicitly and explicitly defined bootstrap methods in the class. If this attributes contents are nonempty and the attribute isnt specified explicitly, one will be added implicitly. This means that you generally dont have to specify it. Its only useful if you care about the exact binary layout of the classfile.
code: code_start code_body code_end
code_start: ".code" "stack" code_limit_t "locals" code_limit_t EOL
code_limit_t: u8 | u16
code_end: ".end" "code"
Note: A Code attribute can only appear as a method attribute. This means that they cannot be nested.
constant_value: ".constantvalue" ldc_rhs
deprecated: ".deprecated"
enclosing_method: ".enclosing" "method" clsref natref
exceptions: ".exceptions" clsref*
inner_classes: ".innerclasses" EOL inner_classes_item* ".end" "innerclasses"
inner_classes_item: cpref cpref utfref flag* EOL
line_number_table: ".linenumbertable" EOL line_number* ".end" "linenumbertable"
line_number: label u16 EOL
local_variable_table: ".localvariabletable" EOL local_variable* ".end" "localvariabletable"
local_variable: u16 "is" utfref utfref code_range EOL
local_variable_type_table: ".localvariabletypetable" EOL local_variable_type* ".end" "localvariabletypetable"
local_variable_type: u16 "is" utfref utfref code_range EOL
method_parameters: ".methodparameters" EOL method_parameter_item* ".end" "methodparameters"
method_parameter_item: utfref flag* EOL
runtime_annotations: ".runtime" visibility (normal_annotations | parameter_annotations | type_annotations) ".end" "runtime"
visibility: "visible" | "invisible"
normal_annotations: "annotations" EOL annotation_line*
parameter_annotations: "paramannotations" EOL parameter_annotation_line*
type_annotations: "typeannotations" EOL type_annotation_line*
signature: ".signature" utfref
source_debug_extension: ".sourcedebugextension" STRING_LITERAL
source_file: ".sourcefile" utfref
stack_map_table: ".stackmaptable"
Note: The content of a StackMapTable attribute is automatically filled in based on the stack directives in the enclosing code attribute. If this attributes contents are nonempty and the attribute isnt specified explicitly, one will be added implicitly. This means that you generally dont have to specify it. Its only useful if you care about the exact binary layout of the classfile.
Note: The StackMapTable attribute depends entirely on the .stack directives specified. Krakatau will not calculate a new stack map for you from bytecode that does not have any stack information. If you want to do this, you should try using ASM.
synthetic: ".synthetic"
Code
code_body: (instruction_line | code_directive)* attribute*
code_directive: catch_directive | stack_directive | ".noimplicitstackmap"
catch_directive: ".catch" clsref code_range "using" label EOL
code_range: "from" label "to" label
stack_directive: ".stack" stackmapitem EOL
stackmapitem: stackmapitem_simple | stackmapitem_stack1 | stackmapitem_append | stackmapitem_full
stackmapitem_simple: "same" | "same_extended" | "chop" INT_LITERAL
stackmapitem_stack1: ("stack_1" | "stack_1_extended") verification_type
stackmapitem_append: "append" vt1to3
vt1to3: verification_type verification_type? verification_type?
stackmapitem_full: "full" EOL "locals" vtlist "stack" vtlist ".end" "stack"
vtlist: verification_type* EOL
verification_type: "Top" | "Integer" | "Float" | "Double" | "Long" | "Null" | "UninitializedThis" | "Object" clsref | "Uninitialized" label
instruction_line: (LBLDEF | LBLDEF? instruction) EOL
instruction: simple_instruction | complex_instruction
simple_instruction: op_none | op_short u8 | op_iinc u8 s8 | op_bipush s8 | op_sipush s16 | op_lbl label | op_fmim fmimref | on_invint fmimref u8? | op_invdyn invdynref | op_cls clsref | op_cls_int clsref u8 | op_ldc ldc_rhs
op_none: "nop" | "aconst_null" | "iconst_m1" | "iconst_0" | "iconst_1" | "iconst_2" | "iconst_3" | "iconst_4" | "iconst_5" | "lconst_0" | "lconst_1" | "fconst_0" | "fconst_1" | "fconst_2" | "dconst_0" | "dconst_1" | "iload_0" | "iload_1" | "iload_2" | "iload_3" | "lload_0" | "lload_1" | "lload_2" | "lload_3" | "fload_0" | "fload_1" | "fload_2" | "fload_3" | "dload_0" | "dload_1" | "dload_2" | "dload_3" | "aload_0" | "aload_1" | "aload_2" | "aload_3" | "iaload" | "laload" | "faload" | "daload" | "aaload" | "baload" | "caload" | "saload" | "istore_0" | "istore_1" | "istore_2" | "istore_3" | "lstore_0" | "lstore_1" | "lstore_2" | "lstore_3" | "fstore_0" | "fstore_1" | "fstore_2" | "fstore_3" | "dstore_0" | "dstore_1" | "dstore_2" | "dstore_3" | "astore_0" | "astore_1" | "astore_2" | "astore_3" | "iastore" | "lastore" | "fastore" | "dastore" | "aastore" | "bastore" | "castore" | "sastore" | "pop" | "pop2" | "dup" | "dup_x1" | "dup_x2" | "dup2" | "dup2_x1" | "dup2_x2" | "swap" | "iadd" | "ladd" | "fadd" | "dadd" | "isub" | "lsub" | "fsub" | "dsub" | "imul" | "lmul" | "fmul" | "dmul" | "idiv" | "ldiv" | "fdiv" | "ddiv" | "irem" | "lrem" | "frem" | "drem" | "ineg" | "lneg" | "fneg" | "dneg" | "ishl" | "lshl" | "ishr" | "lshr" | "iushr" | "lushr" | "iand" | "land" | "ior" | "lor" | "ixor" | "lxor" | "i2l" | "i2f" | "i2d" | "l2i" | "l2f" | "l2d" | "f2i" | "f2l" | "f2d" | "d2i" | "d2l" | "d2f" | "i2b" | "i2c" | "i2s" | "lcmp" | "fcmpl" | "fcmpg" | "dcmpl" | "dcmpg" | "ireturn" | "lreturn" | "freturn" | "dreturn" | "areturn" | "return" | "arraylength" | "athrow" | "monitorenter" | "monitorexit"
op_short: "iload" | "lload" | "fload" | "dload" | "aload" | "istore" | "lstore" | "fstore" | "dstore" | "astore" | "ret"
op_iinc: "iinc"
op_bipush: "bipush"
op_sipush: "sipush"
op_lbl: "ifeq" | "ifne" | "iflt" | "ifge" | "ifgt" | "ifle" | "if_icmpeq" | "if_icmpne" | "if_icmplt" | "if_icmpge" | "if_icmpgt" | "if_icmple" | "if_acmpeq" | "if_acmpne" | "goto" | "jsr" | "ifnull" | "ifnonnull" | "goto_w" | "jsr_w"
op_fmim: "getstatic" | "putstatic" | "getfield" | "putfield" | "invokevirtual" | "invokespecial" | "invokestatic"
on_invint: "invokeinterface"
op_invdyn: "invokedynamic"
op_cls: "new" | "anewarray" | "checkcast" | "instanceof"
op_cls_int: "multianewarray"
op_ldc: "ldc" | "ldc_w" | "ldc2_w"
complex_instruction: ins_newarr | ins_lookupswitch | ins_tableswitch | ins_wide
ins_newarr: "newarray" nacode
nacode: "boolean" | "char" | "float" | "double" | "byte" | "short" | "int" | "long"
ins_lookupswitch: "lookupswitch" EOL luentry* defaultentry
luentry: s32 ":" label EOL
defaultentry: "default:" label
ins_tableswitch: "tableswitch" s32 EOL tblentry* defaultentry
tblentry: label EOL
ins_wide: "wide" (op_short u16 | op_iinc u16 s16)
label: WORD
Annotations
element_value_line: element_value EOL
element_value: primtag ldc_rhs | "string" utfref | "class" utfref | "enum" utfref utfref | element_value_array | "annotation" annotation_contents annotation_end
primtag: "byte" | "char" | "double" | "int" | "float" | "long" | "short" | "boolean"
element_value_array: "array" EOL element_value_line* ".end" "array"
annotation_line: annotation EOL
annotation: ".annotation" annotation_contents annotation_end
annotation_contents: utfref key_ev_line*
key_ev_line: utfref "=" element_value_line
annotation_end: ".end" "annotation"
parameter_annotation_line: parameter_annotation EOL
parameter_annotation: ".paramannotation" EOL annotation_line* ".end" "paramannotation"
type_annotation_line: type_annotation EOL
type_annotation: ".typeannotation" u8 target_info EOL target_path EOL type_annotation_rest
target_info: type_parameter_target | supertype_target | type_parameter_bound_target | empty_target | method_formal_parameter_target | throws_target | localvar_target | catch_target | offset_target | type_argument_target
type_parameter_target: "typeparam" u8
supertype_target: "super" u16
type_parameter_bound_target: "typeparambound" u8 u8
empty_target: "empty"
method_formal_parameter_target: "methodparam" u8
throws_target: "throws" u16
localvar_target: "localvar" EOL localvarrange* ".end" "localvar"
localvarrange: (code_range | "nowhere") u16 EOL
catch_target: "catch" u16
offset_target: "offset" label
type_argument_target: "typearg" label u8
target_path: ".typepath" EOL type_path_segment* ".end" "typepath"
type_path_segment: u8 u8 EOL
type_annotation_rest: annotation_contents ".end" "typeannotation"

View File

@ -0,0 +1,12 @@
2018.01.26:
Add .noimplicitstackmap
2016.10.18: (Backwards incompatible)
String element values take utfref, rather than ldcrhs
2016.06.10:
Allow +, -, * to be used in WORDS, except at the beginning of a WORD. This allows all signature values to be specified without quotes.
Allow argument count of invokeinterface to be omitted when descriptor is specified inline.
2015.12.23: (Backwards incompatible)
Complete rewrite of the assembler. To many changes to list.

View File

@ -0,0 +1,246 @@
from .pool import Pool, utf
from .writer import Writer, Label
def writeU16Count(data, error, objects, message):
count = len(objects)
if count >= 1<<16:
error('Maximum {} count is {}, found {}.'.format(message, (1<<16)-1, count), objects[-1].tok)
data.u16(count)
class Code(object):
def __init__(self, tok, short):
self.tok = tok
self.short = short
self.locals = self.stack = 0
self.bytecode = Writer()
self.exceptions = Writer()
self.exceptcount = 0
self.stackdata = Writer()
self.stackcount = 0
self.stackcountpos = self.stackdata.ph16()
self.laststackoff = -1 # first frame doesn't subtract 1 from offset
self.stackmaptable = None
self.dont_generate_stackmap = False
self.attributes = []
self.labels = {}
self.maxcodelen = (1<<16 if short else 1<<32) - 1
def labeldef(self, lbl, error):
if lbl.sym in self.labels:
error('Duplicate label definition', lbl.tok,
'Previous definition here:', self.labels[lbl.sym][0])
self.labels[lbl.sym] = lbl.tok, self.bytecode.pos
def catch(self, ref, fromlbl, tolbl, usinglbl):
self.exceptcount += 1
self.exceptions.lbl(fromlbl, 0, 'u16')
self.exceptions.lbl(tolbl, 0, 'u16')
self.exceptions.lbl(usinglbl, 0, 'u16')
self.exceptions.ref(ref)
def assembleNoCP(self, data, error):
bytecode = self.bytecode
if self.short:
data.u8(self.stack), data.u8(self.locals), data.u16(len(bytecode))
else:
data.u16(self.stack), data.u16(self.locals), data.u32(len(bytecode))
data += bytecode
data.u16(self.exceptcount)
data += self.exceptions
if self.stackmaptable is None and self.stackcount > 0 and not self.dont_generate_stackmap:
# Use arbitrary token in case we need to report errors
self.stackmaptable = Attribute(self.tok, b'StackMapTable')
self.attributes.append(self.stackmaptable)
if self.stackmaptable:
self.stackdata.setph16(self.stackcountpos, self.stackcount)
self.stackmaptable.data = self.stackdata
writeU16Count(data, error, self.attributes, 'attribute')
for attr in self.attributes:
attr.assembleNoCP(data, error)
return data.fillLabels(self.labels, error)
class Attribute(object):
def __init__(self, tok, name, length=None):
assert tok
if isinstance(name, bytes):
name = utf(tok, name)
self.tok = tok
self.name = name
self.length = length
self.data = Writer()
def assembleNoCP(self, data, error):
length = len(self.data) if self.length is None else self.length
if length >= 1<<32:
error('Maximum attribute data length is {} bytes, got {} bytes.'.format((1<<32)-1, length), self.tok)
data.ref(self.name)
data.u32(length)
data += self.data
return data
class Method(object):
def __init__(self, tok, access, name, desc):
self.tok = tok
self.access = access
self.name = name
self.desc = desc
self.attributes = []
def assembleNoCP(self, data, error):
data.u16(self.access)
data.ref(self.name)
data.ref(self.desc)
writeU16Count(data, error, self.attributes, 'attribute')
for attr in self.attributes:
attr.assembleNoCP(data, error)
return data
class Field(object):
def __init__(self, tok, access, name, desc):
self.tok = tok
self.access = access
self.name = name
self.desc = desc
self.attributes = []
def assembleNoCP(self, data, error):
data.u16(self.access)
data.ref(self.name)
data.ref(self.desc)
writeU16Count(data, error, self.attributes, 'attribute')
for attr in self.attributes:
attr.assembleNoCP(data, error)
return data
class Class(object):
def __init__(self):
self.version = 49, 0
self.access = self.this = self.super = None
self.interfaces = []
self.fields = []
self.methods = []
self.attributes = []
self.useshortcodeattrs = False
self.bootstrapmethods = None
self.pool = Pool()
def _getName(self):
cpool = self.pool.cp
clsind = self.this.resolved_index
if not cpool.slots.get(clsind):
return None
if cpool.slots[clsind].type != 'Class':
return None
utfind = cpool.slots[clsind].refs[0].resolved_index
if utfind not in cpool.slots:
return None
return cpool.slots[utfind].data
def _assembleNoCP(self, error):
beforepool = Writer()
afterpool = Writer()
beforepool.u32(0xCAFEBABE)
beforepool.u16(self.version[1])
beforepool.u16(self.version[0])
afterpool.u16(self.access)
afterpool.ref(self.this)
afterpool.ref(self.super)
writeU16Count(afterpool, error, self.interfaces, 'interface')
for i in self.interfaces:
afterpool.ref(i)
writeU16Count(afterpool, error, self.fields, 'field')
for field in self.fields:
field.assembleNoCP(afterpool, error)
writeU16Count(afterpool, error, self.methods, 'method')
for method in self.methods:
method.assembleNoCP(afterpool, error)
attrcountpos = afterpool.ph16()
afterbs = Writer()
data = afterpool
for attr in self.attributes:
if attr is self.bootstrapmethods:
# skip writing this attr for now and switch to after bs stream
data = afterbs
else:
attr.assembleNoCP(data, error)
return beforepool, afterpool, afterbs, attrcountpos
def assemble(self, error):
beforepool, afterpool, afterbs, attrcountpos = self._assembleNoCP(error)
self.pool.cp.freezedefs(self.pool, error)
self.pool.bs.freezedefs(self.pool, error)
# afterpool is the only part that can contain ldcs
assert not beforepool.refu8phs
assert not afterbs.refu8phs
for _, ref in afterpool.refu8phs:
ind = ref.resolve(self.pool, error)
if ind >= 256:
error("Ldc references too many distinct constants in this class. If you don't want to see this message again, use ldc_w instead of ldc everywhere.", ref.tok)
beforepool.fillRefs(self.pool, error)
afterpool.fillRefs(self.pool, error)
afterbs.fillRefs(self.pool, error)
# Figure out if we need to add an implicit BootstrapMethods attribute
self.pool.resolveIDBSRefs(error)
if self.bootstrapmethods is None and self.pool.bs.slots:
assert len(afterbs) == 0
# Use arbitrary token in case we need to report errors
self.bootstrapmethods = Attribute(self.this.tok, b'BootstrapMethods')
self.attributes.append(self.bootstrapmethods)
if self.bootstrapmethods is not None:
self.bootstrapmethods.name.resolve(self.pool, error)
assert len(self.bootstrapmethods.data) == 0
if len(self.attributes) >= 1<<16:
error('Maximum class attribute count is 65535, found {}.'.format(count), self.attributes[-1].tok)
afterpool.setph16(attrcountpos, len(self.attributes))
cpdata, bsmdata = self.pool.write(error)
assert len(bsmdata) < (1 << 32)
data = beforepool
data += cpdata
data += afterpool
if self.bootstrapmethods is not None:
self.bootstrapmethods.data = bsmdata
self.bootstrapmethods.assembleNoCP(data, error)
data.fillRefs(self.pool, error)
data += afterbs
else:
assert len(afterbs) == 0
name = self._getName()
if name is None:
error('Invalid reference for class name.', self.this.tok)
return name, data.toBytes()

View File

@ -0,0 +1,12 @@
_handle_types = 'getField getStatic putField putStatic invokeVirtual invokeStatic invokeSpecial newInvokeSpecial invokeInterface'.split()
handle_codes = dict(zip(_handle_types, range(1,10)))
handle_rcodes = {v:k for k,v in handle_codes.items()}
newarr_rcodes = ([None]*4) + 'boolean char float double byte short int long'.split()
newarr_codes = dict(zip('boolean char float double byte short int long'.split(), range(4,12)))
vt_rcodes = ['Top','Integer','Float','Double','Long','Null','UninitializedThis','Object','Uninitialized']
vt_codes = {k:i for i,k in enumerate(vt_rcodes)}
et_rtags = dict(zip(map(ord, 'BCDFIJSZsec@['), 'byte char double float int long short boolean string enum class annotation array'.split()))
et_tags = {v:k for k,v in et_rtags.items()}

View File

@ -0,0 +1,843 @@
from __future__ import print_function
import collections
import math
import re
from ..classfileformat import classdata, mutf8
from ..classfileformat.reader import Reader, TruncatedStreamError
from ..util.thunk import thunk
from . import codes, token_regexes
from . import flags
from .instructions import OPNAMES, OP_CLS, OP_FMIM, OP_LBL, OP_NONE, OP_SHORT
MAX_INLINE_SIZE = 300
MAX_INDENT = 20
WORD_REGEX = re.compile(token_regexes.WORD + r'\Z')
PREFIXES = {'Utf8': 'u', 'Class': 'c', 'String': 's', 'Field': 'f', 'Method': 'm', 'InterfaceMethod': 'im', 'NameAndType': 'nat', 'MethodHandle': 'mh', 'MethodType': 'mt', 'InvokeDynamic': 'id'}
class DisassemblyError(Exception):
pass
def reprbytes(b):
# repr will have b in python3 but not python2
return 'b' + repr(b).lstrip('b')
def isword(s):
try:
s = s.decode('ascii')
except UnicodeDecodeError:
return False
return WORD_REGEX.match(s) and s not in flags.FLAGS
def format_string(s):
try:
u = mutf8.decode(s)
except UnicodeDecodeError:
print('Warning, invalid utf8 data!')
else:
if mutf8.encode(u) == s:
return repr(u).lstrip('u')
return reprbytes(s)
def make_signed(x, bits):
if x >= (1 << (bits - 1)):
x -= 1 << bits
return x
class StackMapReader(object):
def __init__(self):
self.stream = None
self.tag = -1
self.pos = -1
self.count = 0
self.valid = True
def setdata(self, r):
if self.stream is None:
self.stream = r
self.count = self.u16() + 1
self.parseNextPos()
else:
# Multiple StackMapTable attributes in same Code attribute
self.valid = False
def parseNextPos(self):
self.count -= 1
if self.count > 0:
self.tag = tag = self.u8()
if tag <= 127: # same and stack_1
delta = tag % 64
else: # everything else has 16bit delta field
delta = self.u16()
self.pos += delta + 1
def u8(self):
try:
return self.stream.u8()
except TruncatedStreamError:
self.valid = False
return 0
def u16(self):
try:
return self.stream.u16()
except TruncatedStreamError:
self.valid = False
return 0
class ReferencePrinter(object):
def __init__(self, clsdata, roundtrip):
self.roundtrip = roundtrip
self.cpslots = clsdata.pool.slots
for attr in clsdata.getattrs(b'BootstrapMethods'):
self.bsslots = classdata.BootstrapMethodsData(attr.stream()).slots
break
else:
self.bsslots = []
# CP index 0 should always be a raw reference. Additionally, there is one case where exact
# references are significant due to a bug in the JVM. In the InnerClasses attribute,
# specifying the same index for inner and outer class will fail verification, but specifying
# different indexes which point to identical class entries will pass. In this case, we force
# references to those indexes to be raw, so they don't get merged and break the class.
self.forcedraw = set()
for attr in clsdata.getattrs(b'InnerClasses'):
r = attr.stream()
for _ in range(r.u16()):
inner, outer, _, _ = r.u16(), r.u16(), r.u16(), r.u16()
if inner != outer and clsdata.pool.getclsutf(inner) == clsdata.pool.getclsutf(outer):
self.forcedraw.add(inner)
self.forcedraw.add(outer)
self.explicit_forcedraw = self.forcedraw.copy()
# For invalid cp indices, just output raw ref instead of throwing (including 0)
for i, slot in enumerate(self.cpslots):
if slot.tag is None:
self.forcedraw.add(i)
self.forcedraw.update(range(len(self.cpslots), 65536))
self.used = set()
self.encoded = {}
self.utfcounts = {}
def _float_or_double(self, x, nmbits, nebits, suffix, nanfmt):
nbits = nmbits + nebits + 1
assert nbits % 32 == 0
sbit, ebits, mbits = x >> (nbits - 1), (x >> nmbits) % (1 << nebits), x % (1 << nmbits)
if ebits == (1 << nebits) - 1:
result = 'NaN' if mbits else 'Infinity'
if self.roundtrip and mbits:
result += nanfmt.format(x)
elif ebits == 0 and mbits == 0:
result = '0.0'
else:
ebias = (1 << (nebits - 1)) - 1
exponent = ebits - ebias - nmbits
mantissa = mbits
if ebits > 0:
mantissa += 1 << nmbits
else:
exponent += 1
if self.roundtrip:
result = '0x{:X}p{}'.format(mantissa, exponent)
else:
result = repr(math.ldexp(mantissa, exponent))
return '+-'[sbit] + result + suffix
def _int(self, x): return str(make_signed(x, 32))
def _long(self, x): return str(make_signed(x, 64)) + 'L'
def _float(self, x): return self._float_or_double(x, 23, 8, 'f', '<0x{:08X}>')
def _double(self, x): return self._float_or_double(x, 52, 11, '', '<0x{:016X}>')
def _encode_utf(self, ind, wordok=True):
try:
return self.encoded[ind][wordok]
except KeyError:
s = self.cpslots[ind].data
string = format_string(s)
word = s.decode() if isword(s) else string
self.encoded[ind] = [string, word]
return word if wordok else string
def rawref(self, ind, isbs=False):
return '[{}{}]'.format('bs:' if isbs else '', ind)
def symref(self, ind, isbs=False):
self.used.add((ind, isbs))
if isbs:
return '[bs:_{}]'.format(ind)
prefix = PREFIXES.get(self.cpslots[ind].tag, '_')
return '[{}{}]'.format(prefix, ind)
def ref(self, ind, isbs=False):
if self.roundtrip or not isbs and ind in self.forcedraw:
return self.rawref(ind, isbs)
return self.symref(ind, isbs)
def _ident(self, ind, wordok=True):
if self.cpslots[ind].tag == 'Utf8':
val = self._encode_utf(ind, wordok=wordok)
if len(val) < MAX_INLINE_SIZE:
if len(val) < 50 or self.utfcounts.get(ind, 0) < 10:
self.utfcounts[ind] = 1 + self.utfcounts.get(ind, 0)
return val
def utfref(self, ind):
if self.roundtrip or ind in self.forcedraw:
return self.rawref(ind)
temp = self._ident(ind)
if temp is not None:
return temp
return self.symref(ind)
def clsref(self, ind, tag='Class'):
assert tag in 'Class Module Package'.split()
if self.roundtrip or ind in self.forcedraw:
return self.rawref(ind)
if self.cpslots[ind].tag == tag:
ind2 = self.cpslots[ind].refs[0]
temp = self._ident(ind2)
if temp is not None:
return temp
return self.symref(ind)
def natref(self, ind):
if self.roundtrip or ind in self.forcedraw:
return self.rawref(ind)
if self.cpslots[ind].tag == 'NameAndType':
ind2, ind3 = self.cpslots[ind].refs
temp = self._ident(ind2)
if temp is not None:
return temp + ' ' + self.utfref(ind3)
return self.symref(ind)
def fmimref(self, ind):
if self.roundtrip or ind in self.forcedraw:
return self.rawref(ind)
if self.cpslots[ind].tag in ['Field', 'Method', 'InterfaceMethod']:
ind2, ind3 = self.cpslots[ind].refs
return ' '.join([self.cpslots[ind].tag, self.clsref(ind2), self.natref(ind3)])
return self.symref(ind)
def mhnotref(self, ind):
slot = self.cpslots[ind]
return codes.handle_rcodes[slot.data] + ' ' + self.taggedref(slot.refs[0], allowed=['Field', 'Method', 'InterfaceMethod'])
def taggedconst(self, ind):
slot = self.cpslots[ind]
if slot.tag == 'Utf8':
parts = [self._encode_utf(ind)]
elif slot.tag == 'Int':
parts = [self._int(slot.data)]
elif slot.tag == 'Float':
parts = [self._float(slot.data)]
elif slot.tag == 'Long':
parts = [self._long(slot.data)]
elif slot.tag == 'Double':
parts = [self._double(slot.data)]
elif slot.tag in ['Class', 'String', 'MethodType', 'Module', 'Package']:
parts = [self.utfref(slot.refs[0])]
elif slot.tag in ['Field', 'Method', 'InterfaceMethod']:
parts = [self.clsref(slot.refs[0]), self.natref(slot.refs[1])]
elif slot.tag == 'NameAndType':
parts = [self.utfref(slot.refs[0]), self.utfref(slot.refs[1])]
elif slot.tag == 'MethodHandle':
parts = [self.mhnotref(ind)]
elif slot.tag == 'InvokeDynamic':
parts = [self.bsref(slot.refs[0]), self.natref(slot.refs[1])]
parts.insert(0, slot.tag)
return ' '.join(parts)
def taggedref(self, ind, allowed=None):
if self.roundtrip or ind in self.forcedraw:
return self.rawref(ind)
if allowed is None or self.cpslots[ind].tag in allowed:
temp = self.taggedconst(ind)
if len(temp) < MAX_INLINE_SIZE:
return temp
return self.symref(ind)
def ldcrhs(self, ind):
if self.roundtrip or ind in self.forcedraw:
return self.rawref(ind)
slot = self.cpslots[ind]
t = slot.tag
if t == 'Int':
return self._int(slot.data)
elif slot.tag == 'Float':
return self._float(slot.data)
elif slot.tag == 'Long':
return self._long(slot.data)
elif slot.tag == 'Double':
return self._double(slot.data)
elif t == 'String':
ind2 = self.cpslots[ind].refs[0]
temp = self._ident(ind2, wordok=False)
if temp is not None:
return temp
return self.symref(ind)
return self.taggedref(ind, allowed=['Class', 'MethodHandle', 'MethodType'])
def bsnotref(self, ind, tagged=False):
slot = self.bsslots[ind]
parts = []
if tagged:
parts.append('Bootstrap')
if tagged and self.roundtrip:
parts.append(self.rawref(slot.refs[0]))
else:
parts.append(self.mhnotref(slot.refs[0]))
for bsarg in slot.refs[1:]:
parts.append(self.taggedref(bsarg))
parts.append(':')
return ' '.join(parts)
def bsref(self, ind):
if self.roundtrip:
return self.rawref(ind, isbs=True)
return self.bsnotref(ind)
LabelInfos = collections.namedtuple('LabelInfos', 'defined used')
class Disassembler(object):
def __init__(self, clsdata, out, roundtrip):
self.roundtrip = roundtrip
self.out = out
self.cls = clsdata
self.pool = clsdata.pool
self.indentlevel = 0
self.labels = None
self.refprinter = ReferencePrinter(clsdata, roundtrip)
def _getattr(a, obj, name):
for attr in obj.attributes:
if a.pool.getutf(attr.name) == name:
return attr
def sol(a, text=''):
level = min(a.indentlevel, MAX_INDENT) * 4
text += ' ' * (level - len(text))
a.out(text)
def eol(a): a.out('\n')
def val(a, s): a.out(s + ' ')
def int(a, x): a.val(str(x))
def lbl(a, x):
a.labels.used.add(x)
a.val('L{}'.format(x))
def try_lbl(a, x):
if a.labels is None or x not in a.labels.defined:
raise DisassemblyError()
a.lbl(x)
###########################################################################
def extrablankline(a): a.eol()
def ref(a, ind, isbs=False): a.val(a.refprinter.ref(ind, isbs))
def utfref(a, ind): a.val(a.refprinter.utfref(ind))
def clsref(a, ind, tag='Class'): a.val(a.refprinter.clsref(ind, tag))
def natref(a, ind): a.val(a.refprinter.natref(ind))
def fmimref(a, ind): a.val(a.refprinter.fmimref(ind))
def taggedbs(a, ind): a.val(a.refprinter.bsnotref(ind, tagged=True))
def taggedconst(a, ind): a.val(a.refprinter.taggedconst(ind))
def taggedref(a, ind): a.val(a.refprinter.taggedref(ind))
def ldcrhs(a, ind): a.val(a.refprinter.ldcrhs(ind))
def flags(a, access, names):
for i in range(16):
if access & (1 << i):
a.val(names[1 << i])
###########################################################################
### Top level stuff (class, const defs, fields, methods) ##################
def disassemble(a):
cls = a.cls
a.val('.version'), a.int(cls.version[0]), a.int(cls.version[1]), a.eol()
a.val('.class'), a.flags(cls.access, flags.RFLAGS_CLASS), a.clsref(cls.this), a.eol()
a.val('.super'), a.clsref(cls.super), a.eol()
for ref in cls.interfaces:
a.val('.implements'), a.clsref(ref), a.eol()
for f in cls.fields:
a.field(f)
for m in cls.methods:
a.method(m)
for attr in cls.attributes:
a.attribute(attr)
a.constdefs()
a.val('.end class'), a.eol()
def field(a, f):
a.val('.field'), a.flags(f.access, flags.RFLAGS_FIELD), a.utfref(f.name), a.utfref(f.desc)
attrs = f.attributes[:]
cvattr = a._getattr(f, b'ConstantValue')
if cvattr and not a.roundtrip:
a.val('='), a.ldcrhs(cvattr.stream().u16())
attrs.remove(cvattr)
if attrs:
a.val('.fieldattributes'), a.eol()
a.indentlevel += 1
for attr in attrs:
a.attribute(attr)
a.indentlevel -= 1
a.val('.end fieldattributes')
a.eol()
def method(a, m):
a.extrablankline()
a.val('.method'), a.flags(m.access, flags.RFLAGS_METHOD), a.utfref(m.name), a.val(':'), a.utfref(m.desc), a.eol()
a.indentlevel += 1
for attr in m.attributes:
a.attribute(attr, in_method=True)
a.indentlevel -= 1
a.val('.end method'), a.eol()
def constdefs(a):
if a.roundtrip:
for ind in range(len(a.refprinter.cpslots)):
a.constdef(ind, False)
for ind in range(len(a.refprinter.bsslots)):
a.constdef(ind, True)
else:
assert not a.refprinter.used & a.refprinter.forcedraw
for ind in sorted(a.refprinter.explicit_forcedraw):
a.constdef(ind, False)
done = set()
while len(done) < len(a.refprinter.used):
for ind, isbs in sorted(a.refprinter.used - done):
a.constdef(ind, isbs)
done.add((ind, isbs))
def constdef(a, ind, isbs):
if not isbs and a.refprinter.cpslots[ind].tag is None:
return
a.sol(), a.val('.bootstrap' if isbs else '.const'), a.ref(ind, isbs), a.val('=')
if isbs:
a.taggedbs(ind)
else:
a.taggedconst(ind)
a.eol()
###########################################################################
### Bytecode ##############################################################
def code(a, r):
c = classdata.CodeData(r, a.pool, a.cls.version < (45, 3))
a.val('.code'), a.val('stack'), a.int(c.stack), a.val('locals'), a.int(c.locals), a.eol()
a.indentlevel += 1
assert a.labels is None
a.labels = LabelInfos(set(), set())
stackreader = StackMapReader()
for attr in c.attributes:
if a.pool.getutf(attr.name) == b'StackMapTable':
stackreader.setdata(attr.stream())
rexcepts = c.exceptions[::-1]
bcreader = Reader(c.bytecode)
while bcreader.size():
a.insline_start(bcreader.off, rexcepts, stackreader)
a.instruction(bcreader)
a.insline_start(bcreader.off, rexcepts, stackreader), a.eol()
badlbls = a.labels.used - a.labels.defined
if badlbls:
stackreader.valid = False
a.sol('; Labels used by invalid StackMapTable attribute'), a.eol()
for pos in sorted(badlbls):
a.sol('L{}'.format(pos) + ':'), a.eol()
if stackreader.stream and (stackreader.stream.size() or stackreader.count > 0):
stackreader.valid = False
if not stackreader.valid:
a.sol('.noimplicitstackmap'), a.eol()
for attr in c.attributes:
a.attribute(attr, use_raw_stackmap=not stackreader.valid)
a.labels = None
a.indentlevel -= 1
a.sol(), a.val('.end code')
def insline_start(a, pos, rexcepts, stackreader):
while rexcepts and rexcepts[-1].start <= pos:
e = rexcepts.pop()
a.sol(), a.val('.catch'), a.clsref(e.type), a.val('from'), a.lbl(e.start)
a.val('to'), a.lbl(e.end), a.val('using'), a.lbl(e.handler), a.eol()
if stackreader.count > 0 and stackreader.pos == pos:
r = stackreader
tag = stackreader.tag
a.extrablankline()
a.sol(), a.val('.stack')
if tag <= 63:
a.val('same')
elif tag <= 127:
a.val('stack_1'), a.verification_type(r)
elif tag == 247:
a.val('stack_1_extended'), a.verification_type(r)
elif tag < 251:
a.val('chop'), a.int(251 - tag)
elif tag == 251:
a.val('same_extended')
elif tag < 255:
a.val('append')
for _ in range(tag - 251):
a.verification_type(r)
else:
a.val('full')
a.indentlevel += 1
a.eol(), a.sol(), a.val('locals')
for _ in range(r.u16()):
a.verification_type(r)
a.eol(), a.sol(), a.val('stack')
for _ in range(r.u16()):
a.verification_type(r)
a.indentlevel -= 1
a.eol(), a.sol(), a.val('.end stack')
a.eol()
stackreader.parseNextPos()
a.sol('L{}'.format(pos) + ':')
a.labels.defined.add(pos)
def verification_type(a, r):
try:
tag = codes.vt_rcodes[r.u8()]
except IndexError:
r.valid = False
a.val('Top')
return
a.val(tag)
if tag == 'Object':
a.clsref(r.u16())
elif tag == 'Uninitialized':
a.lbl(r.u16())
def instruction(a, r):
pos = r.off
op = OPNAMES[r.u8()]
a.val(op)
if op in OP_LBL:
a.lbl(pos + (r.s32() if op.endswith('_w') else r.s16()))
elif op in OP_SHORT:
a.int(r.u8())
elif op in OP_CLS:
a.clsref(r.u16())
elif op in OP_FMIM:
a.fmimref(r.u16())
elif op == 'invokeinterface':
a.fmimref(r.u16()), a.int(r.u8()), r.u8()
elif op == 'invokedynamic':
a.taggedref(r.u16()), r.u16()
elif op in ['ldc', 'ldc_w', 'ldc2_w']:
a.ldcrhs(r.u8() if op == 'ldc' else r.u16())
elif op == 'multianewarray':
a.clsref(r.u16()), a.int(r.u8())
elif op == 'bipush':
a.int(r.s8())
elif op == 'sipush':
a.int(r.s16())
elif op == 'iinc':
a.int(r.u8()), a.int(r.s8())
elif op == 'wide':
op2 = OPNAMES[r.u8()]
a.val(op2), a.int(r.u16())
if op2 == 'iinc':
a.int(r.s16())
elif op == 'newarray':
a.val(codes.newarr_rcodes[r.u8()])
elif op == 'tableswitch':
r.getRaw((3-pos) % 4)
default = pos + r.s32()
low, high = r.s32(), r.s32()
a.int(low), a.eol()
a.indentlevel += 1
for _ in range(high - low + 1):
a.sol(), a.lbl(pos + r.s32()), a.eol()
a.sol(), a.val('default'), a.val(':'), a.lbl(default), a.eol()
a.indentlevel -= 1
elif op == 'lookupswitch':
r.getRaw((3-pos) % 4)
default = pos + r.s32()
a.eol()
a.indentlevel += 1
for _ in range(r.s32()):
a.sol(), a.int(r.s32()), a.val(':'), a.lbl(pos + r.s32()), a.eol()
a.sol(), a.val('default'), a.val(':'), a.lbl(default), a.eol()
a.indentlevel -= 1
else:
assert op in OP_NONE
a.eol()
###########################################################################
### Attributes ############################################################
def attribute(a, attr, in_method=False, use_raw_stackmap=False):
name = a.pool.getutf(attr.name)
if not a.roundtrip and name in (b'BootstrapMethods', b'StackMapTable'):
return
# a.extrablankline()
a.sol()
isnamed = False
if a.roundtrip or name is None:
isnamed = True
a.val('.attribute'), a.utfref(attr.name)
if attr.wronglength:
a.val('length'), a.int(attr.length)
if name == b'Code' and in_method:
a.code(attr.stream())
elif name == b'BootstrapMethods' and a.cls.version >= (51, 0):
a.val('.bootstrapmethods')
elif name == b'StackMapTable' and not use_raw_stackmap:
a.val('.stackmaptable')
elif a.attribute_fallible(name, attr):
pass
else:
print('Nonstandard attribute', name[:70], len(attr.raw))
if not isnamed:
a.val('.attribute'), a.utfref(attr.name)
a.val(reprbytes(attr.raw))
a.eol()
def attribute_fallible(a, name, attr):
# Temporarily buffer output so we don't get partial output if attribute disassembly fails
# in case of failure, we'll fall back to binary output in the caller
orig_out = a.out
buffer_ = []
a.out = buffer_.append
try:
r = attr.stream()
if name == b'AnnotationDefault':
a.val('.annotationdefault'), a.element_value(r)
elif name == b'ConstantValue':
a.val('.constantvalue'), a.ldcrhs(r.u16())
elif name == b'Deprecated':
a.val('.deprecated')
elif name == b'EnclosingMethod':
a.val('.enclosing method'), a.clsref(r.u16()), a.natref(r.u16())
elif name == b'Exceptions':
a.val('.exceptions')
for _ in range(r.u16()):
a.clsref(r.u16())
elif name == b'InnerClasses':
a.indented_line_list(r, a._innerclasses_item, 'innerclasses')
elif name == b'LineNumberTable':
a.indented_line_list(r, a._linenumber_item, 'linenumbertable')
elif name == b'LocalVariableTable':
a.indented_line_list(r, a._localvariabletable_item, 'localvariabletable')
elif name == b'LocalVariableTypeTable':
a.indented_line_list(r, a._localvariabletable_item, 'localvariabletypetable')
elif name == b'MethodParameters':
a.indented_line_list(r, a._methodparams_item, 'methodparameters', bytelen=True)
elif name == b'Module':
a.module_attr(r)
elif name == b'ModuleMainClass':
a.val('.modulemainclass'), a.clsref(r.u16())
elif name == b'ModulePackages':
a.val('.modulepackages')
for _ in range(r.u16()):
a.clsref(r.u16(), tag='Package')
elif name in (b'RuntimeVisibleAnnotations', b'RuntimeVisibleParameterAnnotations',
b'RuntimeVisibleTypeAnnotations', b'RuntimeInvisibleAnnotations',
b'RuntimeInvisibleParameterAnnotations', b'RuntimeInvisibleTypeAnnotations'):
a.val('.runtime')
a.val('invisible' if b'Inv' in name else 'visible')
if b'Type' in name:
a.val('typeannotations'), a.eol()
a.indented_line_list(r, a.type_annotation_line, 'runtime', False)
elif b'Parameter' in name:
a.val('paramannotations'), a.eol()
a.indented_line_list(r, a.param_annotation_line, 'runtime', False, bytelen=True)
else:
a.val('annotations'), a.eol()
a.indented_line_list(r, a.annotation_line, 'runtime', False)
elif name == b'Signature':
a.val('.signature'), a.utfref(r.u16())
elif name == b'SourceDebugExtension':
a.val('.sourcedebugextension')
a.val(reprbytes(attr.raw))
elif name == b'SourceFile':
a.val('.sourcefile'), a.utfref(r.u16())
elif name == b'Synthetic':
a.val('.synthetic')
# check for extra data in the attribute
if r.size():
raise DisassemblyError()
except (TruncatedStreamError, DisassemblyError):
a.out = orig_out
return False
a.out = orig_out
a.out(''.join(buffer_))
return True
def module_attr(a, r):
a.val('.module'), a.clsref(r.u16(), tag='Module'), a.flags(r.u16(), flags.RFLAGS_MOD_OTHER)
a.val('version'), a.utfref(r.u16()), a.eol()
a.indentlevel += 1
for _ in range(r.u16()):
a.sol(), a.val('.requires'), a.clsref(r.u16(), tag='Module'), a.flags(r.u16(), flags.RFLAGS_MOD_REQUIRES), a.val('version'), a.utfref(r.u16()), a.eol()
for dir_ in ('.exports', '.opens'):
for _ in range(r.u16()):
a.sol(), a.val(dir_), a.clsref(r.u16(), tag='Package'), a.flags(r.u16(), flags.RFLAGS_MOD_OTHER)
count = r.u16()
if count:
a.val('to')
for _ in range(count):
a.clsref(r.u16(), tag='Module')
a.eol()
for _ in range(r.u16()):
a.sol(), a.val('.uses'), a.clsref(r.u16()), a.eol()
for _ in range(r.u16()):
a.sol(), a.val('.provides'), a.clsref(r.u16()), a.val('with')
for _ in range(r.u16()):
a.clsref(r.u16())
a.eol()
a.indentlevel -= 1
a.sol(), a.val('.end module')
def indented_line_list(a, r, cb, dirname, dostart=True, bytelen=False):
if dostart:
a.val('.' + dirname), a.eol()
a.indentlevel += 1
for _ in range(r.u8() if bytelen else r.u16()):
a.sol(), cb(r), a.eol()
a.indentlevel -= 1
if dirname is not None:
a.sol(), a.val('.end ' + dirname)
def _innerclasses_item(a, r): a.clsref(r.u16()), a.clsref(r.u16()), a.utfref(r.u16()), a.flags(r.u16(), flags.RFLAGS_CLASS)
def _linenumber_item(a, r): a.try_lbl(r.u16()), a.int(r.u16())
def _localvariabletable_item(a, r):
start, length, name, desc, ind = r.u16(), r.u16(), r.u16(), r.u16(), r.u16()
a.int(ind), a.val('is'), a.utfref(name), a.utfref(desc),
a.val('from'), a.try_lbl(start), a.val('to'), a.try_lbl(start + length)
def _methodparams_item(a, r): a.utfref(r.u16()), a.flags(r.u16(), flags.RFLAGS_MOD_OTHER)
###########################################################################
### Annotations ###########################################################
def annotation_line(a, r):
a.val('.annotation'), a.annotation_contents(r), a.sol(), a.val('.end'), a.val('annotation')
def param_annotation_line(a, r):
a.indented_line_list(r, a.annotation_line, 'paramannotation')
def type_annotation_line(a, r):
a.val('.typeannotation')
a.indentlevel += 1
a.ta_target_info(r) # Note: begins on same line as .typeannotation
a.ta_target_path(r)
a.sol(), a.annotation_contents(r),
a.indentlevel -= 1
a.sol(), a.val('.end'), a.val('typeannotation')
def ta_target_info(a, r):
tag = r.u8()
a.int(tag)
if tag <= 0x01:
a.val('typeparam'), a.int(r.u8())
elif tag <= 0x10:
a.val('super'), a.int(r.u16())
elif tag <= 0x12:
a.val('typeparambound'), a.int(r.u8()), a.int(r.u8())
elif tag <= 0x15:
a.val('empty')
elif tag <= 0x16:
a.val('methodparam'), a.int(r.u8())
elif tag <= 0x17:
a.val('throws'), a.int(r.u16())
elif tag <= 0x41:
a.val('localvar'), a.eol()
a.indented_line_list(r, a._localvarrange, 'localvar', False)
elif tag <= 0x42:
a.val('catch'), a.int(r.u16())
elif tag <= 0x46:
a.val('offset'), a.try_lbl(r.u16())
else:
a.val('typearg'), a.try_lbl(r.u16()), a.int(r.u8())
a.eol()
def _localvarrange(a, r):
start, length, index = r.u16(), r.u16(), r.u16()
if start == length == 0xFFFF: # WTF, Java?
a.val('nowhere')
else:
a.val('from'), a.try_lbl(start), a.val('to'), a.try_lbl(start + length)
a.int(index)
def ta_target_path(a, r):
a.sol(), a.indented_line_list(r, a._type_path_segment, 'typepath', bytelen=True), a.eol()
def _type_path_segment(a, r):
a.int(r.u8()), a.int(r.u8())
# The following are recursive and can be nested arbitrarily deep,
# so we use generators and a thunk to avoid the Python stack limit.
def element_value(a, r): thunk(a._element_value(r))
def annotation_contents(a, r): thunk(a._annotation_contents(r))
def _element_value(a, r):
tag = codes.et_rtags.get(r.u8())
if tag is None:
raise DisassemblyError()
a.val(tag)
if tag == 'annotation':
(yield a._annotation_contents(r)), a.sol(), a.val('.end'), a.val('annotation')
elif tag == 'array':
a.eol()
a.indentlevel += 1
for _ in range(r.u16()):
a.sol(), (yield a._element_value(r)), a.eol()
a.indentlevel -= 1
a.sol(), a.val('.end'), a.val('array')
elif tag == 'enum':
a.utfref(r.u16()), a.utfref(r.u16())
elif tag == 'class' or tag == 'string':
a.utfref(r.u16())
else:
a.ldcrhs(r.u16())
def _annotation_contents(a, r):
a.utfref(r.u16()), a.eol()
a.indentlevel += 1
for _ in range(r.u16()):
a.sol(), a.utfref(r.u16()), a.val('='), (yield a._element_value(r)), a.eol()
a.indentlevel -= 1

View File

@ -0,0 +1,43 @@
_pairs = [
('public', 0x0001),
('private', 0x0002),
('protected', 0x0004),
('static', 0x0008),
('final', 0x0010),
('super', 0x0020),
# todo - order attributes properly by context
('transitive', 0x0020),
('open', 0x0020),
('synchronized', 0x0020),
('volatile', 0x0040),
('static_phase', 0x0040),
('bridge', 0x0040),
('transient', 0x0080),
('varargs', 0x0080),
('native', 0x0100),
('interface', 0x0200),
('abstract', 0x0400),
('strict', 0x0800),
('synthetic', 0x1000),
('annotation', 0x2000),
('enum', 0x4000),
('module', 0x8000),
('mandated', 0x8000),
]
FLAGS = dict(_pairs)
# Treat strictfp as flag too to reduce confusion
FLAGS['strictfp'] = FLAGS['strict']
def _make_dict(priority):
d = {v:k for k,v in reversed(_pairs)}
# ensure that the specified flags have priority
for flag in priority.split():
d[FLAGS[flag]] = flag
return d
RFLAGS_CLASS = _make_dict('super module')
RFLAGS_FIELD = _make_dict('volatile transient')
RFLAGS_METHOD = _make_dict('synchronized bridge varargs')
RFLAGS_MOD_REQUIRES = _make_dict('transitive static_phase mandated')
RFLAGS_MOD_OTHER = _make_dict('open mandated')

View File

@ -0,0 +1,18 @@
OP_NONE = frozenset(['nop', 'aconst_null', 'iconst_m1', 'iconst_0', 'iconst_1', 'iconst_2', 'iconst_3', 'iconst_4', 'iconst_5', 'lconst_0', 'lconst_1', 'fconst_0', 'fconst_1', 'fconst_2', 'dconst_0', 'dconst_1', 'iload_0', 'iload_1', 'iload_2', 'iload_3', 'lload_0', 'lload_1', 'lload_2', 'lload_3', 'fload_0', 'fload_1', 'fload_2', 'fload_3', 'dload_0', 'dload_1', 'dload_2', 'dload_3', 'aload_0', 'aload_1', 'aload_2', 'aload_3', 'iaload', 'laload', 'faload', 'daload', 'aaload', 'baload', 'caload', 'saload', 'istore_0', 'istore_1', 'istore_2', 'istore_3', 'lstore_0', 'lstore_1', 'lstore_2', 'lstore_3', 'fstore_0', 'fstore_1', 'fstore_2', 'fstore_3', 'dstore_0', 'dstore_1', 'dstore_2', 'dstore_3', 'astore_0', 'astore_1', 'astore_2', 'astore_3', 'iastore', 'lastore', 'fastore', 'dastore', 'aastore', 'bastore', 'castore', 'sastore', 'pop', 'pop2', 'dup', 'dup_x1', 'dup_x2', 'dup2', 'dup2_x1', 'dup2_x2', 'swap', 'iadd', 'ladd', 'fadd', 'dadd', 'isub', 'lsub', 'fsub', 'dsub', 'imul', 'lmul', 'fmul', 'dmul', 'idiv', 'ldiv', 'fdiv', 'ddiv', 'irem', 'lrem', 'frem', 'drem', 'ineg', 'lneg', 'fneg', 'dneg', 'ishl', 'lshl', 'ishr', 'lshr', 'iushr', 'lushr', 'iand', 'land', 'ior', 'lor', 'ixor', 'lxor', 'i2l', 'i2f', 'i2d', 'l2i', 'l2f', 'l2d', 'f2i', 'f2l', 'f2d', 'd2i', 'd2l', 'd2f', 'i2b', 'i2c', 'i2s', 'lcmp', 'fcmpl', 'fcmpg', 'dcmpl', 'dcmpg', 'ireturn', 'lreturn', 'freturn', 'dreturn', 'areturn', 'return', 'arraylength', 'athrow', 'monitorenter', 'monitorexit'])
OP_SHORT = frozenset(['iload', 'lload', 'fload', 'dload', 'aload', 'istore', 'lstore', 'fstore', 'dstore', 'astore', 'ret'])
OP_LBL = frozenset(['ifeq', 'ifne', 'iflt', 'ifge', 'ifgt', 'ifle', 'if_icmpeq', 'if_icmpne', 'if_icmplt', 'if_icmpge', 'if_icmpgt', 'if_icmple', 'if_acmpeq', 'if_acmpne', 'goto', 'jsr', 'ifnull', 'ifnonnull', 'goto_w', 'jsr_w'])
OP_CLS = frozenset(['new', 'anewarray', 'checkcast', 'instanceof'])
OP_FMIM_TO_GUESS = {
'getstatic': 'Field',
'putstatic': 'Field',
'getfield': 'Field',
'putfield': 'Field',
'invokevirtual': 'Method',
'invokespecial': 'Method',
'invokestatic': 'Method',
}
OP_FMIM = frozenset(OP_FMIM_TO_GUESS)
OPNAMES = ('nop', 'aconst_null', 'iconst_m1', 'iconst_0', 'iconst_1', 'iconst_2', 'iconst_3', 'iconst_4', 'iconst_5', 'lconst_0', 'lconst_1', 'fconst_0', 'fconst_1', 'fconst_2', 'dconst_0', 'dconst_1', 'bipush', 'sipush', 'ldc', 'ldc_w', 'ldc2_w', 'iload', 'lload', 'fload', 'dload', 'aload', 'iload_0', 'iload_1', 'iload_2', 'iload_3', 'lload_0', 'lload_1', 'lload_2', 'lload_3', 'fload_0', 'fload_1', 'fload_2', 'fload_3', 'dload_0', 'dload_1', 'dload_2', 'dload_3', 'aload_0', 'aload_1', 'aload_2', 'aload_3', 'iaload', 'laload', 'faload', 'daload', 'aaload', 'baload', 'caload', 'saload', 'istore', 'lstore', 'fstore', 'dstore', 'astore', 'istore_0', 'istore_1', 'istore_2', 'istore_3', 'lstore_0', 'lstore_1', 'lstore_2', 'lstore_3', 'fstore_0', 'fstore_1', 'fstore_2', 'fstore_3', 'dstore_0', 'dstore_1', 'dstore_2', 'dstore_3', 'astore_0', 'astore_1', 'astore_2', 'astore_3', 'iastore', 'lastore', 'fastore', 'dastore', 'aastore', 'bastore', 'castore', 'sastore', 'pop', 'pop2', 'dup', 'dup_x1', 'dup_x2', 'dup2', 'dup2_x1', 'dup2_x2', 'swap', 'iadd', 'ladd', 'fadd', 'dadd', 'isub', 'lsub', 'fsub', 'dsub', 'imul', 'lmul', 'fmul', 'dmul', 'idiv', 'ldiv', 'fdiv', 'ddiv', 'irem', 'lrem', 'frem', 'drem', 'ineg', 'lneg', 'fneg', 'dneg', 'ishl', 'lshl', 'ishr', 'lshr', 'iushr', 'lushr', 'iand', 'land', 'ior', 'lor', 'ixor', 'lxor', 'iinc', 'i2l', 'i2f', 'i2d', 'l2i', 'l2f', 'l2d', 'f2i', 'f2l', 'f2d', 'd2i', 'd2l', 'd2f', 'i2b', 'i2c', 'i2s', 'lcmp', 'fcmpl', 'fcmpg', 'dcmpl', 'dcmpg', 'ifeq', 'ifne', 'iflt', 'ifge', 'ifgt', 'ifle', 'if_icmpeq', 'if_icmpne', 'if_icmplt', 'if_icmpge', 'if_icmpgt', 'if_icmple', 'if_acmpeq', 'if_acmpne', 'goto', 'jsr', 'ret', 'tableswitch', 'lookupswitch', 'ireturn', 'lreturn', 'freturn', 'dreturn', 'areturn', 'return', 'getstatic', 'putstatic', 'getfield', 'putfield', 'invokevirtual', 'invokespecial','invokestatic', 'invokeinterface', 'invokedynamic', 'new', 'newarray', 'anewarray', 'arraylength', 'athrow', 'checkcast', 'instanceof', 'monitorenter', 'monitorexit', 'wide', 'multianewarray', 'ifnull', 'ifnonnull', 'goto_w', 'jsr_w')
OPNAME_TO_BYTE = {v:i for i, v in enumerate(OPNAMES)}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,259 @@
import collections
import itertools
from .writer import Writer
TAGS = [None, 'Utf8', None, 'Int', 'Float', 'Long', 'Double', 'Class', 'String', 'Field', 'Method', 'InterfaceMethod', 'NameAndType', None, None, 'MethodHandle', 'MethodType', None, 'InvokeDynamic', 'Module', 'Package']
class Ref(object):
def __init__(self, tok, index=None, symbol=None, type=None, refs=None, data=None, isbs=False):
self.tok = tok
self.isbs = isbs
self.index = index
self.symbol = symbol
assert type == 'Bootstrap' or type in TAGS
self.type = type
self.refs = refs or []
self.data = data
self.resolved_index = None
def israw(self): return self.index is not None
def issym(self): return self.symbol is not None
def _deepdata(self, pool, error, defstack=()):
if self.issym():
return pool.sub(self).getroot(self, error)._deepdata(pool, error, defstack)
if self.israw():
return 'Raw', self.index
if len(defstack) > 5: # Maximum legitimate depth is 5: ID -> BS -> MH -> F -> NAT -> UTF
error_args = ['Constant pool definitions cannot be nested more than 5 deep (excluding raw references).', self.tok]
for ref in reversed(defstack):
error_args.append('Included from {} ref here:'.format(ref.type))
error_args.append(ref.tok)
error(*error_args)
return self.type, self.data, tuple(ref._deepdata(pool, error, defstack + (self,)) for ref in self.refs)
def _resolve(self, pool, error):
if self.israw():
return self.index
if self.issym():
return pool.sub(self).getroot(self, error).resolve(pool, error)
return pool.sub(self).resolvedata(self, error, self._deepdata(pool, error))
def resolve(self, pool, error):
if self.resolved_index is None:
self.resolved_index = self._resolve(pool, error)
assert self.resolved_index is not None
return self.resolved_index
def __str__(self): # pragma: no cover
prefix = 'bs:' if self.isbs else ''
if self.israw():
return '[{}{}]'.format(prefix, self.index)
elif self.issym():
return '[{}{}]'.format(prefix, self.symbol)
parts = [self.type] + self.refs
if self.data is not None:
parts.insert(1, self.data)
return ' '.join(map(str, parts))
def utf(tok, s):
assert isinstance(s, bytes)
assert len(s) <= 65535
return Ref(tok, type='Utf8', data=s)
def single(type, tok, s):
assert type in 'Class String MethodType Module Package'.split()
return Ref(tok, type=type, refs=[utf(tok, s)])
def nat(name, desc):
return Ref(name.tok, type='NameAndType', refs=[name, desc])
def primitive(type, tok, x):
assert type in 'Int Long Float Double'.split()
return Ref(tok, type=type, data=x)
class PoolSub(object):
def __init__(self, isbs):
self.isbs = isbs
self.symdefs = {}
self.symrootdefs = {}
self.slot_def_tokens = {}
self.slots = collections.OrderedDict()
self.dataToSlot = {}
self.narrowcounter = itertools.count()
self.widecounter = itertools.count()
self.dirtyslotdefs = []
self.defsfrozen = False
if not isbs:
self.slots[0] = None
def adddef(self, lhs, rhs, error):
assert not self.defsfrozen
assert lhs.israw() or lhs.issym()
if lhs.israw():
if lhs.index == 0 and not self.isbs:
error('Constant pool index must be nonzero', lhs.tok)
if lhs.index in self.slots:
error('Conflicting raw reference definition', lhs.tok,
'Conflicts with previous definition:', self.slot_def_tokens[lhs.index])
self.slots[lhs.index] = rhs
self.slot_def_tokens[lhs.index] = lhs.tok
self.dirtyslotdefs.append(lhs.index)
assert rhs.type
if rhs.type in ('Long', 'Double'):
if lhs.index + 1 in self.slots:
error('Conflicting raw reference definition', lhs.tok,
'Conflicts with previous definition:', self.slot_def_tokens[lhs.index + 1])
self.slots[lhs.index + 1] = None
self.slot_def_tokens[lhs.index + 1] = lhs.tok
else:
if lhs.symbol in self.symdefs:
error('Duplicate symbolic reference definition', lhs.tok,
'Previously defined here:', self.symdefs[lhs.symbol][0])
self.symdefs[lhs.symbol] = lhs.tok, rhs
def freezedefs(self, pool, error): self.defsfrozen = True
def _getslot(self, iswide):
assert self.defsfrozen
if iswide:
ind = next(self.widecounter)
while ind in self.slots or ind + 1 in self.slots:
ind = next(self.widecounter)
if ind + 1 >= 0xFFFF:
return None
else:
ind = next(self.narrowcounter)
while ind in self.slots:
ind = next(self.narrowcounter)
if ind >= 0xFFFF:
return None
return ind
def getroot(self, ref, error):
assert self.defsfrozen and ref.issym()
try:
return self.symrootdefs[ref.symbol]
except KeyError:
stack = []
visited = set()
while ref.issym():
sym = ref.symbol
if sym in visited:
error_args = ['Circular symbolic reference', ref.tok]
for tok in stack[::-1]:
error_args.extend(('Included from here:', tok))
error(*error_args)
stack.append(ref.tok)
visited.add(sym)
if sym not in self.symdefs:
error('Undefined symbolic reference', ref.tok)
_, ref = self.symdefs[sym]
for sym in visited:
self.symrootdefs[sym] = ref
return ref
def resolvedata(self, ref, error, newdata):
try:
return self.dataToSlot[newdata]
except KeyError:
iswide = newdata[0] in ('Long', 'Double')
slot = self._getslot(iswide)
if slot is None:
name = 'bootstrap method' if ref.isbs else 'constant pool'
error('Exhausted {} space.'.format(name), ref.tok)
self.dataToSlot[newdata] = slot
self.slots[slot] = ref
self.dirtyslotdefs.append(slot)
if iswide:
self.slots[slot + 1] = None
return slot
def resolveslotrefs(self, pool, error):
while len(self.dirtyslotdefs) > 0:
i = self.dirtyslotdefs.pop()
for ref in self.slots[i].refs:
ref.resolve(pool, error)
def writeconst(self, w, ref, pool, error):
t = ref.type
w.u8(TAGS.index(t))
if t == 'Utf8':
w.u16(len(ref.data))
w.writeBytes(ref.data)
elif t == 'Int' or t == 'Float':
w.u32(ref.data)
elif t == 'Long' or t == 'Double':
w.u64(ref.data)
elif t == 'MethodHandle':
w.u8(ref.data)
w.u16(ref.refs[0].resolve(pool, error))
else:
for child in ref.refs:
w.u16(child.resolve(pool, error))
return w
def writebootstrap(self, w, ref, pool, error):
assert ref.type == 'Bootstrap'
w.u16(ref.refs[0].resolve(pool, error))
w.u16(len(ref.refs)-1)
for child in ref.refs[1:]:
w.u16(child.resolve(pool, error))
return w
def write(self, pool, error):
self.resolveslotrefs(pool, error)
self.dirtyslotdefs = None # make sure we don't accidently add entries after size is taken
size = max(self.slots) + 1 if self.slots else 0
dummyentry = b'\1\0\0' # empty UTF8
if self.isbs and self.slots:
first = next(iter(self.slots.values()))
dummyentry = self.writebootstrap(Writer(), first, pool, error).toBytes()
w = Writer()
w.u16(size)
for i in range(size):
if i not in self.slots:
w.writeBytes(dummyentry)
continue
v = self.slots[i]
if v is None:
continue
if self.isbs:
self.writebootstrap(w, v, pool, error)
if len(w) >= (1<<32):
error('Maximum BootstrapMethods length is {} bytes.'.format((1<<32)-1), v.tok)
else:
self.writeconst(w, v, pool, error)
return w
class Pool(object):
def __init__(self):
self.cp = PoolSub(False)
self.bs = PoolSub(True)
def sub(self, ref): return self.bs if ref.isbs else self.cp
def resolveIDBSRefs(self, error):
for v in self.cp.slots.values():
if v is not None and v.type == 'InvokeDynamic':
v.refs[0].resolve(self, error)
def write(self, error):
bsmdata = self.bs.write(self, error)
cpdata = self.cp.write(self, error)
return cpdata, bsmdata

View File

@ -0,0 +1,45 @@
DIRECTIVE = r'\.[a-z]+'
WORD = r'(?:[a-zA-Z_$\(<]|\[[A-Z\[])[\w$;\/\[\(\)<>*+-]*'
FOLLOWED_BY_WHITESPACE = r'(?=\s|\Z)'
REF = r'\[[a-z0-9_:]+\]'
LABEL_DEF = r'L\w+:'
COMMENT = r';.*'
# Match optional comment and at least one newline, followed by any number of empty/whitespace lines
NEWLINES = r'(?:{})?\n\s*'.format(COMMENT)
HEX_DIGIT = r'[0-9a-fA-F]'
ESCAPE_SEQUENCE = r'''\\(?:U00(?:10|0{hd}){hd}{{4}}|u{hd}{{4}}|x{hd}{{2}}|[btnfr'"\\0-7])'''.format(hd=HEX_DIGIT)
# See http://stackoverflow.com/questions/430759/regex-for-managing-escaped-characters-for-items-like-string-literals/5455705# 5455705
STRING_LITERAL = r'''
[bB]?(?:
"
[^"\n\\]* # any number of unescaped characters
(?:{es}[^"\n\\]* # escape sequence followed by 0 or more unescaped
)*
"
|
'
[^'\n\\]* # any number of unescaped characters
(?:{es}[^'\n\\]* # escape sequence followed by 0 or more unescaped
)*
'
)'''.format(es=ESCAPE_SEQUENCE)
# For error detection
STRING_START = r'''[bB]?(?:"(?:[^"\\\n]|{es})*|'(?:[^'\\\n]|{es})*)'''.format(es=ESCAPE_SEQUENCE)
# Careful here: | is not greedy so hex must come first
INT_LITERAL = r'[+-]?(?:0[xX]{hd}+|[1-9][0-9]*|0)[lL]?'.format(hd=HEX_DIGIT)
FLOAT_LITERAL = r'''(?:
(?:
[-+][Ii][Nn][Ff][Ii][Nn][Ii][Tt][Yy]| # Nan and Inf both have mandatory sign
[-+][Nn][Aa][Nn]
(?:<0[xX]{hd}+>)? # Optional suffix for nonstandard NaNs
)|
[-+]?(?:
\d+\.\d+(?:[eE][+-]?\d+)?| # decimal float
\d+[eE][+-]?\d+| # decimal float with no fraction (exponent mandatory)
0[xX]{hd}+(?:\.{hd}+)?[pP][+-]?\d+ # hexidecimal float
)
)[fF]?
'''.format(hd=HEX_DIGIT)

View File

@ -0,0 +1,113 @@
from __future__ import print_function
import collections
import re
import sys
from . import token_regexes as res
class AsssemblerError(Exception):
pass
Token = collections.namedtuple('Token', 'type val pos')
TOKENS = [
('WHITESPACE', r'[ \t]+'),
('WORD', res.WORD + res.FOLLOWED_BY_WHITESPACE),
('DIRECTIVE', res.DIRECTIVE + res.FOLLOWED_BY_WHITESPACE),
('LABEL_DEF', res.LABEL_DEF + res.FOLLOWED_BY_WHITESPACE),
('NEWLINES', res.NEWLINES),
('REF', res.REF + res.FOLLOWED_BY_WHITESPACE),
('COLON', r':' + res.FOLLOWED_BY_WHITESPACE),
('EQUALS', r'=' + res.FOLLOWED_BY_WHITESPACE),
('INT_LITERAL', res.INT_LITERAL + res.FOLLOWED_BY_WHITESPACE),
('DOUBLE_LITERAL', res.FLOAT_LITERAL + res.FOLLOWED_BY_WHITESPACE),
('STRING_LITERAL', res.STRING_LITERAL + res.FOLLOWED_BY_WHITESPACE),
]
REGEX = re.compile('|'.join('(?P<{}>{})'.format(*pair) for pair in TOKENS), re.VERBOSE)
# For error detection
STRING_START_REGEX = re.compile(res.STRING_START)
WORD_LIKE_REGEX = re.compile(r'.\S*')
MAXLINELEN = 80
def formatError(source, filename, message, point, point2):
try:
start = source.rindex('\n', 0, point) + 1
except ValueError:
start = 0
line_start = start
try:
end = source.index('\n', start) + 1
except ValueError: # pragma: no cover
end = len(source) + 1
# Find an 80 char section of the line around the point of interest to display
temp = min(point2, point + MAXLINELEN//2)
if temp < start + MAXLINELEN:
end = min(end, start + MAXLINELEN)
elif point >= end - MAXLINELEN:
start = max(start, end - MAXLINELEN)
else:
mid = (point + temp) // 2
start = max(start, mid - MAXLINELEN//2)
end = min(end, start + MAXLINELEN)
point2 = min(point2, end)
assert line_start <= start <= point < point2 <= end
pchars = [' '] * (end - start)
for i in range(point - start, point2 - start):
pchars[i] = '~'
pchars[point - start] = '^'
lineno = source[:line_start].count('\n') + 1
colno = point - line_start + 1
return '{}:{}:{}: {}\n{}\n{}'.format(filename, lineno, colno,
message, source[start:end].rstrip('\n'), ''.join(pchars))
class Tokenizer(object):
def __init__(self, source, filename):
self.s = source
self.pos = 0
self.atlineend = True
if isinstance(filename, bytes):
filename = filename.decode()
self.filename = filename.rpartition('/')[-1]
def error(self, error, *notes):
printerr = lambda s: print(s, file=sys.stderr)
message, point, point2 = error
printerr(formatError(self.s, self.filename, 'error: ' + message, point, point2))
for message, point, point2 in notes:
printerr(formatError(self.s, self.filename, 'note: ' + message, point, point2))
raise AsssemblerError()
def _nextsub(self):
match = REGEX.match(self.s, self.pos)
if match is None:
if self.atend():
return Token('EOF', '', self.pos)
else:
str_match = STRING_START_REGEX.match(self.s, self.pos)
if str_match is not None:
self.error(('Invalid escape sequence or character in string literal', str_match.end(), str_match.end()+1))
match = WORD_LIKE_REGEX.match(self.s, self.pos)
return Token('INVALID_TOKEN', match.group(), match.start())
assert match.start() == match.pos == self.pos
self.pos = match.end()
return Token(match.lastgroup, match.group(), match.start())
def next(self):
tok = self._nextsub()
while tok.type == 'WHITESPACE' or self.atlineend and tok.type == 'NEWLINES':
tok = self._nextsub()
self.atlineend = tok.type == 'NEWLINES'
if tok.type == 'INT_LITERAL' and tok.val.lower().endswith('l'):
return tok._replace(type='LONG_LITERAL')
elif tok.type == 'DOUBLE_LITERAL' and tok.val.lower().endswith('f'):
return tok._replace(type='FLOAT_LITERAL')
return tok
def atend(self): return self.pos == len(self.s)

View File

@ -0,0 +1,138 @@
import collections
import struct
Label = collections.namedtuple('Label', ['tok', 'sym'])
class Writer(object):
def __init__(self):
self.b = bytearray()
self.refphs = []
self.refu8phs = []
self.lblphs = []
# includes lbl and manual phs but not ref phs
self._ph8s = set()
self._ph16s = set()
self._ph32s = set()
@property
def pos(self): return len(self.b)
def u8(self, x): self.b.append(x)
def s8(self, x): self.b.append(x % 256)
def u16(self, x): self.b.extend(struct.pack('>H', x))
def s16(self, x): self.b.extend(struct.pack('>h', x))
def u32(self, x): self.b.extend(struct.pack('>I', x))
def s32(self, x): self.b.extend(struct.pack('>i', x))
def u64(self, x): self.b.extend(struct.pack('>Q', x))
def writeBytes(self, b): self.b.extend(b)
def ref(self, ref):
self.refphs.append((self.pos, ref))
self.u16(0)
def refu8(self, ref):
self.refu8phs.append((self.pos, ref))
self.u8(0)
def ph8(self):
pos = self.pos
self.u8(0)
self._ph8s.add(pos)
return pos
def ph16(self):
pos = self.pos
self.u16(0)
self._ph16s.add(pos)
return pos
def ph32(self):
pos = self.pos
self.u32(0)
self._ph32s.add(pos)
return pos
def lbl(self, lbl, base, dtype):
pos = self.ph32() if dtype == 's32' else self.ph16()
self.lblphs.append((pos, lbl, base, dtype))
def lblrange(self, start, end):
self.lbl(start, 0, 'u16')
self.lbl(end, start, 'u16')
def setph8(self, pos, x):
assert self.b[pos] == 0
self.b[pos] = x
self._ph8s.remove(pos)
def setph16(self, pos, x):
assert self.b[pos:pos+2] == b'\0\0'
self.b[pos:pos+2] = struct.pack('>H', x)
self._ph16s.remove(pos)
def setph32(self, pos, x):
assert self.b[pos:pos+4] == b'\0\0\0\0'
self.b[pos:pos+4] = struct.pack('>I', x)
self._ph32s.remove(pos)
def _getlbl(self, lbl, labels, error):
if lbl.sym not in labels:
error('Undefined label', lbl.tok)
return labels[lbl.sym][1]
def fillLabels(self, labels, error):
for pos, lbl, base, dtype in self.lblphs:
tok = lbl.tok
lbl = self._getlbl(lbl, labels, error)
# base can also be a second label
if isinstance(base, Label):
base = self._getlbl(base, labels, error)
offset = lbl - base
if dtype == 's16':
if not -1<<15 <= offset < 1<<15:
error('Label offset must fit in signed 16 bit int. (offset is {})'.format(offset), tok)
self.setph16(pos, offset % (1<<16))
elif dtype == 'u16':
if not 0 <= offset < 1<<16:
error('Label offset must fit in unsigned 16 bit int. (offset is {})'.format(offset), tok)
self.setph16(pos, offset)
elif dtype == 's32':
if not -1<<31 <= offset < 1<<31:
error('Label offset must fit in signed 32 bit int. (offset is {})'.format(offset), tok)
self.setph32(pos, offset % (1<<32))
else:
assert 0 # pragma: no cover
self.lblphs = []
return self
def fillRefs(self, pool, error):
for pos, ref in self.refu8phs:
self.b[pos] = ref.resolve(pool, error)
for pos, ref in self.refphs:
self.b[pos:pos+2] = struct.pack('>H', ref.resolve(pool, error))
self.refu8phs = []
self.refphs = []
def toBytes(self):
assert not self.refphs and not self.refu8phs
assert not self._ph8s and not self._ph16s and not self._ph32s
return bytes(self.b)
def __len__(self): return len(self.b)
def __iadd__(self, other):
# Make sure there are no manual placeholders in other
assert len(other.lblphs) == len(other._ph8s) + len(other._ph16s) + len(other._ph32s)
offset = self.pos
self.b += other.b
self.refphs.extend((pos + offset, ref) for pos, ref in other.refphs)
self.refu8phs.extend((pos + offset, ref) for pos, ref in other.refu8phs)
self.lblphs.extend((pos + offset, lbl, base, dtype) for pos, lbl, base, dtype in other.lblphs)
self._ph8s.update(pos + offset for pos in other._ph8s)
self._ph16s.update(pos + offset for pos in other._ph16s)
self._ph32s.update(pos + offset for pos in other._ph32s)
return self

View File

@ -0,0 +1,21 @@
def get_attribute_raw(bytestream, ic_indices):
name_ind, length = bytestream.get('>HL')
# Hotspot does not actually check the attribute length of InnerClasses prior to 49.0
# so this case requires special handling. We will keep the purported length of the
# attribute so that it can be displayed in the disassembly. For InnerClass attributes
# data is actually a (length, bytes) tuple, rather than storing the bytes directly
if name_ind in ic_indices:
count = bytestream.get('>H', peek=True)
data = length, bytestream.getRaw(2+8*count)
else:
data = bytestream.getRaw(length)
return name_ind,data
def get_attributes_raw(bytestream, ic_indices=()):
attribute_count = bytestream.get('>H')
return [get_attribute_raw(bytestream, ic_indices) for _ in range(attribute_count)]
def fixAttributeNames(attributes_raw, cpool):
return [(cpool.getArgsCheck('Utf8', name_ind), data) for name_ind, data in attributes_raw]

View File

@ -0,0 +1,217 @@
from __future__ import division
from . import opnames
def parseInstructions(bytestream, isConstructor):
data = bytestream
assert data.off == 0
instructions = {}
while data.size() > 0:
address = data.off
inst = getNextInstruction(data, address)
# replace constructor invocations with synthetic op invokeinit to simplfy things later
if inst[0] == opnames.INVOKESPECIAL and isConstructor(inst[1]):
inst = (opnames.INVOKEINIT,) + inst[1:]
instructions[address] = inst
assert data.size() == 0
return instructions
simpleOps = {0x00:opnames.NOP, 0x01:opnames.CONSTNULL, 0x94:opnames.LCMP,
0xbe:opnames.ARRLEN, 0xbf:opnames.THROW, 0xc2:opnames.MONENTER,
0xc3:opnames.MONEXIT, 0x57:opnames.POP, 0x58:opnames.POP2, 0x59:opnames.DUP,
0x5a:opnames.DUPX1, 0x5b:opnames.DUPX2, 0x5c:opnames.DUP2,
0x5d:opnames.DUP2X1, 0x5e:opnames.DUP2X2, 0x5f:opnames.SWAP}
singleIndexOps = {0xb2:opnames.GETSTATIC,0xb3:opnames.PUTSTATIC,0xb4:opnames.GETFIELD,
0xb5:opnames.PUTFIELD,0xb6:opnames.INVOKEVIRTUAL,0xb7:opnames.INVOKESPECIAL,
0xb8:opnames.INVOKESTATIC, 0xbb:opnames.NEW,0xbd:opnames.ANEWARRAY,
0xc0:opnames.CHECKCAST,0xc1:opnames.INSTANCEOF}
def getNextInstruction(data, address):
byte = data.get('>B')
# typecode - B,C,S, and Bool are only used for array types and sign extension
A,B,C,D,F,I,L,S = "ABCDFIJS"
Bool = "Z"
if byte in simpleOps:
inst = (simpleOps[byte],)
elif byte in singleIndexOps:
inst = (singleIndexOps[byte], data.get('>H'))
elif byte <= 0x11:
op = opnames.CONST
if byte <= 0x08:
t, val = I, byte - 0x03
elif byte <= 0x0a:
t, val = L, byte - 0x09
elif byte <= 0x0d:
t, val = F, float(byte - 0x0b)
elif byte <= 0x0f:
t, val = D, float(byte - 0x0e)
elif byte == 0x10:
t, val = I, data.get('>b')
else:
t, val = I, data.get('>h')
inst = op, t, val
elif byte == 0x12:
inst = opnames.LDC, data.get('>B'), 1
elif byte == 0x13:
inst = opnames.LDC, data.get('>H'), 1
elif byte == 0x14:
inst = opnames.LDC, data.get('>H'), 2
elif byte <= 0x2d:
op = opnames.LOAD
if byte <= 0x19:
t = [I,L,F,D,A][byte - 0x15]
val = data.get('>B')
else:
temp = byte - 0x1a
t = [I,L,F,D,A][temp // 4]
val = temp % 4
inst = op, t, val
elif byte <= 0x35:
op = opnames.ARRLOAD
t = [I,L,F,D,A,B,C,S][byte - 0x2e]
inst = (op, t) if t != A else (opnames.ARRLOAD_OBJ,) # split object case into seperate op name to simplify things later
elif byte <= 0x4e:
op = opnames.STORE
if byte <= 0x3a:
t = [I,L,F,D,A][byte - 0x36]
val = data.get('>B')
else:
temp = byte - 0x3b
t = [I,L,F,D,A][temp // 4]
val = temp % 4
inst = op, t, val
elif byte <= 0x56:
op = opnames.ARRSTORE
t = [I,L,F,D,A,B,C,S][byte - 0x4f]
inst = (op, t) if t != A else (opnames.ARRSTORE_OBJ,) # split object case into seperate op name to simplify things later
elif byte <= 0x77:
temp = byte - 0x60
opt = (opnames.ADD,opnames.SUB,opnames.MUL,opnames.DIV,opnames.REM,opnames.NEG)[temp//4]
t = (I,L,F,D)[temp % 4]
inst = opt, t
elif byte <= 0x83:
temp = byte - 0x78
opt = (opnames.SHL,opnames.SHR,opnames.USHR,opnames.AND,opnames.OR,opnames.XOR)[temp//2]
t = (I,L)[temp % 2]
inst = opt, t
elif byte == 0x84:
inst = opnames.IINC, data.get('>B'), data.get('>b')
elif byte <= 0x90:
op = opnames.CONVERT
pairs = ((I,L),(I,F),(I,D),(L,I),(L,F),(L,D),(F,I),(F,L),(F,D),
(D,I),(D,L),(D,F))
src_t, dest_t = pairs[byte - 0x85]
inst = op, src_t, dest_t
elif byte <= 0x93:
op = opnames.TRUNCATE
dest_t = [B,C,S][byte - 0x91]
inst = op, dest_t
elif byte <= 0x98:
op = opnames.FCMP
temp = byte - 0x95
t = (F,D)[temp//2]
NaN_val = (-1,1)[temp % 2]
inst = op, t, NaN_val
elif byte <= 0x9e:
op = opnames.IF_I
cmp_t = ('eq','ne','lt','ge','gt','le')[byte - 0x99]
jumptarget = data.get('>h') + address
inst = op, cmp_t, jumptarget
elif byte <= 0xa4:
op = opnames.IF_ICMP
cmp_t = ('eq','ne','lt','ge','gt','le')[byte - 0x9f]
jumptarget = data.get('>h') + address
inst = op, cmp_t, jumptarget
elif byte <= 0xa6:
op = opnames.IF_ACMP
cmp_t = ('eq','ne')[byte - 0xa5]
jumptarget = data.get('>h') + address
inst = op, cmp_t, jumptarget
elif byte == 0xa7:
inst = opnames.GOTO, data.get('>h') + address
elif byte == 0xa8:
inst = opnames.JSR, data.get('>h') + address
elif byte == 0xa9:
inst = opnames.RET, data.get('>B')
elif byte == 0xaa: # Table Switch
padding = data.getRaw((3-address) % 4)
default = data.get('>i') + address
low = data.get('>i')
high = data.get('>i')
assert high >= low
numpairs = high - low + 1
offsets = [data.get('>i') + address for _ in range(numpairs)]
jumps = zip(range(low, high+1), offsets)
inst = opnames.SWITCH, default, jumps
elif byte == 0xab: # Lookup Switch
padding = data.getRaw((3-address) % 4)
default = data.get('>i') + address
numpairs = data.get('>i')
assert numpairs >= 0
pairs = [data.get('>ii') for _ in range(numpairs)]
jumps = [(x,(y + address)) for x,y in pairs]
inst = opnames.SWITCH, default, jumps
elif byte <= 0xb1:
op = opnames.RETURN
t = (I,L,F,D,A,None)[byte - 0xac]
inst = op, t
elif byte == 0xb9:
op = opnames.INVOKEINTERFACE
index = data.get('>H')
count, zero = data.get('>B'), data.get('>B')
inst = op, index, count, zero
elif byte == 0xba:
op = opnames.INVOKEDYNAMIC
index = data.get('>H')
zero = data.get('>H')
inst = op, index, zero
elif byte == 0xbc:
typecode = data.get('>b')
types = {4:Bool, 5:C, 6:F, 7:D, 8:B, 9:S, 10:I, 11:L}
t = types.get(typecode)
inst = opnames.NEWARRAY, t
elif byte == 0xc4: # wide
realbyte = data.get('>B')
if realbyte >= 0x15 and realbyte < 0x1a:
t = [I,L,F,D,A][realbyte - 0x15]
inst = opnames.LOAD, t, data.get('>H')
elif realbyte >= 0x36 and realbyte < 0x3b:
t = [I,L,F,D,A][realbyte - 0x36]
inst = opnames.STORE, t, data.get('>H')
elif realbyte == 0xa9:
inst = opnames.RET, data.get('>H')
elif realbyte == 0x84:
inst = opnames.IINC, data.get('>H'), data.get('>h')
else:
assert 0
elif byte == 0xc5:
op = opnames.MULTINEWARRAY
index = data.get('>H')
dim = data.get('>B')
inst = op, index, dim
elif byte <= 0xc7:
op = opnames.IF_A
cmp_t = ('eq','ne')[byte - 0xc6]
jumptarget = data.get('>h') + address
inst = op, cmp_t, jumptarget
elif byte == 0xc8:
inst = opnames.GOTO, data.get('>i') + address
elif byte == 0xc9:
inst = opnames.JSR, data.get('>i') + address
else:
assert 0
return inst
def printInstruction(instr):
if len(instr) == 1:
return instr[0]
elif len(instr) == 2:
return '{}({})'.format(*instr)
else:
return '{}{}'.format(instr[0], instr[1:])

View File

@ -0,0 +1,107 @@
from . import constant_pool, field, method
from .attributes_raw import fixAttributeNames, get_attributes_raw
cp_structFmts = {3: '>i',
4: '>i', # floats and doubles internally represented as integers with same bit pattern
5: '>q',
6: '>q',
7: '>H',
8: '>H',
9: '>HH',
10: '>HH',
11: '>HH',
12: '>HH',
15: '>BH',
16: '>H',
18: '>HH'}
def get_cp_raw(bytestream):
const_count = bytestream.get('>H')
assert const_count > 1
placeholder = None,None
pool = [placeholder]
while len(pool) < const_count:
tag = bytestream.get('B')
if tag == 1: # utf8
length = bytestream.get('>H')
data = bytestream.getRaw(length)
val = tag, (data,)
else:
val = tag,bytestream.get(cp_structFmts[tag], True)
pool.append(val)
# Longs and Doubles take up two spaces in the pool
if tag == 5 or tag == 6:
pool.append(placeholder)
assert len(pool) == const_count
return pool
def get_field_raw(bytestream):
flags, name, desc = bytestream.get('>HHH')
attributes = get_attributes_raw(bytestream)
return flags, name, desc, attributes
def get_fields_raw(bytestream):
count = bytestream.get('>H')
return [get_field_raw(bytestream) for _ in range(count)]
# fields and methods have same raw format
get_method_raw = get_field_raw
get_methods_raw = get_fields_raw
class ClassFile(object):
flagVals = {'PUBLIC':0x0001,
'FINAL':0x0010,
'SUPER':0x0020,
'INTERFACE':0x0200,
'ABSTRACT':0x0400,
'SYNTHETIC':0x1000,
'ANNOTATION':0x2000,
'ENUM':0x4000,
# These flags are only used for InnerClasses attributes
'PRIVATE':0x0002,
'PROTECTED':0x0004,
'STATIC':0x0008,
}
def __init__(self, bytestream):
magic, minor, major = bytestream.get('>LHH')
assert magic == 0xCAFEBABE
self.version = major,minor
const_pool_raw = get_cp_raw(bytestream)
flags, self.this, self.super = bytestream.get('>HHH')
interface_count = bytestream.get('>H')
self.interfaces_raw = [bytestream.get('>H') for _ in range(interface_count)]
self.fields_raw = get_fields_raw(bytestream)
self.methods_raw = get_methods_raw(bytestream)
ic_indices = [i for i,x in enumerate(const_pool_raw) if x == (1, ("InnerClasses",))]
self.attributes_raw = get_attributes_raw(bytestream, ic_indices)
assert bytestream.size() == 0
self.flags = frozenset(name for name,mask in ClassFile.flagVals.items() if (mask & flags))
self.cpool = constant_pool.ConstPool(const_pool_raw)
self.name = self.cpool.getArgsCheck('Class', self.this)
self.elementsLoaded = False
self.env = self.supername = None
self.fields = self.methods = self.attributes = None
if self.super:
self.supername = self.cpool.getArgsCheck('Class', self.super)
def loadElements(self, keepRaw=False):
if self.elementsLoaded:
return
self.fields = [field.Field(m, self, keepRaw) for m in self.fields_raw]
self.methods = [method.Method(m, self, keepRaw) for m in self.methods_raw]
self.attributes = fixAttributeNames(self.attributes_raw, self.cpool)
self.fields_raw = self.methods_raw = None
if not keepRaw:
self.attributes_raw = None
self.elementsLoaded = True

View File

@ -0,0 +1,116 @@
import collections
from .reader import Reader
TAGS = [None, 'Utf8', None, 'Int', 'Float', 'Long', 'Double', 'Class', 'String', 'Field', 'Method', 'InterfaceMethod', 'NameAndType', None, None, 'MethodHandle', 'MethodType', None, 'InvokeDynamic', 'Module', 'Package']
SlotData = collections.namedtuple('SlotData', ['tag', 'data', 'refs'])
ExceptData = collections.namedtuple('ExceptData', ['start', 'end', 'handler', 'type'])
class ConstantPoolData(object):
def __init__(self, r):
self.slots = []
self._null()
size = r.u16()
while len(self.slots) < size:
self._const(r)
def _null(self):
self.slots.append(SlotData(None, None, None))
def _const(self, r):
t = TAGS[r.u8()]
data = None
refs = []
if t == 'Utf8':
data = r.getRaw(r.u16())
elif t == 'Int' or t == 'Float':
data = r.u32()
elif t == 'Long' or t == 'Double':
data = r.u64()
elif t == 'MethodHandle':
data = r.u8()
refs.append(r.u16())
elif t in ['Class', 'String', 'MethodType', 'Module', 'Package']:
refs.append(r.u16())
else:
refs.append(r.u16())
refs.append(r.u16())
self.slots.append(SlotData(t, data, refs))
if t in ('Long', 'Double'):
self._null()
def getutf(self, ind):
if ind < len(self.slots) and self.slots[ind].tag == 'Utf8':
return self.slots[ind].data
def getclsutf(self, ind):
if ind < len(self.slots) and self.slots[ind].tag == 'Class':
return self.getutf(self.slots[ind].refs[0])
class BootstrapMethodsData(object):
def __init__(self, r):
self.slots = []
for _ in range(r.u16()):
first = r.u16()
argcount = r.u16()
refs = [first] + [r.u16() for _ in range(argcount)]
self.slots.append(SlotData('Bootstrap', None, refs))
class CodeData(object):
def __init__(self, r, pool, short):
if short:
self.stack, self.locals, codelen = r.u8(), r.u8(), r.u16()
else:
self.stack, self.locals, codelen = r.u16(), r.u16(), r.u32()
self.bytecode = r.getRaw(codelen)
self.exceptions = [ExceptData(r.u16(), r.u16(), r.u16(), r.u16()) for _ in range(r.u16())]
self.attributes = [AttributeData(r) for _ in range(r.u16())]
class AttributeData(object):
def __init__(self, r, pool=None):
self.name, self.length = r.u16(), r.u32()
# The JVM allows InnerClasses attributes to have a bogus length field,
# and hence we must calculate the length from the contents
if pool and pool.getutf(self.name) == b'InnerClasses':
actual_length = r.copy().u16() * 8 + 2
else:
actual_length = self.length
self.raw = r.getRaw(actual_length)
self.wronglength = actual_length != self.length
def stream(self): return Reader(self.raw)
class FieldData(object):
def __init__(self, r):
self.access, self.name, self.desc = r.u16(), r.u16(), r.u16()
self.attributes = [AttributeData(r) for _ in range(r.u16())]
class MethodData(object):
def __init__(self, r):
self.access, self.name, self.desc = r.u16(), r.u16(), r.u16()
self.attributes = [AttributeData(r) for _ in range(r.u16())]
class ClassData(object):
def __init__(self, r):
magic, minor, major = r.u32(), r.u16(), r.u16()
self.version = major, minor
self.pool = ConstantPoolData(r)
self.access, self.this, self.super = r.u16(), r.u16(), r.u16()
self.interfaces = [r.u16() for _ in range(r.u16())]
self.fields = [FieldData(r) for _ in range(r.u16())]
self.methods = [MethodData(r) for _ in range(r.u16())]
self.attributes = [AttributeData(r, pool=self.pool) for _ in range(r.u16())]
# assert r.done()
def getattrs(self, name):
for attr in self.attributes:
if self.pool.getutf(attr.name) == name:
yield attr

View File

@ -0,0 +1,28 @@
import re
# First alternative handles a single surrogate, in case input string somehow contains unmerged surrogates
NONASTRAL_REGEX = re.compile(u'[\ud800-\udfff]|[\0-\ud7ff\ue000-\uffff]+')
def encode(s):
assert not isinstance(s, bytes)
b = b''
pos = 0
while pos < len(s):
x = ord(s[pos])
if x >= 1<<16:
x -= 1<<16
high = 0xD800 + (x >> 10)
low = 0xDC00 + (x % (1 << 10))
b += unichr(high).encode('utf8')
b += unichr(low).encode('utf8')
pos += 1
else:
m = NONASTRAL_REGEX.match(s, pos)
b += m.group().encode('utf8')
pos = m.end()
return b.replace(b'\0', b'\xc0\x80')
# Warning, decode(encode(s)) != s if s contains astral characters, as they are converted to surrogate pairs
def decode(b):
assert isinstance(b, bytes)
return b.replace(b'\xc0\x80', b'\0').decode('utf8')

View File

@ -0,0 +1,47 @@
import struct
class TruncatedStreamError(EOFError):
pass
class Reader(object):
__slots__ = ['d', 'off']
def __init__(self, data, off=0):
self.d = data
self.off = off
def done(self): return self.off >= len(self.d)
def copy(self): return Reader(self.d, self.off)
def u8(self): return self.get('>B')
def s8(self): return self.get('>b')
def u16(self): return self.get('>H')
def s16(self): return self.get('>h')
def u32(self): return self.get('>I')
def s32(self): return self.get('>i')
def u64(self): return self.get('>Q')
# binUnpacker functions
def get(self, fmt, forceTuple=False, peek=False):
size = struct.calcsize(fmt)
if self.size() < size:
raise TruncatedStreamError()
val = struct.unpack_from(fmt, self.d, self.off)
if not peek:
self.off += size
if not forceTuple and len(val) == 1:
val = val[0]
return val
def getRaw(self, num):
if self.size() < num:
raise TruncatedStreamError()
val = self.d[self.off:self.off+num]
self.off += num
return val
def size(self):
return len(self.d) - self.off

View File

@ -0,0 +1,103 @@
import collections
import struct
# ConstantPool stores strings as strings or unicodes. They are automatically
# converted to and from modified Utf16 when reading and writing to binary
# Floats and Doubles are internally stored as integers with the same bit pattern
# Since using raw floats breaks equality testing for signed zeroes and NaNs
# cpool.getArgs/getArgsCheck will automatically convert them into Python floats
def decodeStr(s):
return s.replace('\xc0\x80','\0').decode('utf8'),
def decodeFloat(i):
return struct.unpack('>f', struct.pack('>i', i)) # Note: returns tuple
def decodeDouble(i):
return struct.unpack('>d', struct.pack('>q', i))
cpoolInfo_t = collections.namedtuple('cpoolInfo_t',
['name','tag','recoverArgs'])
Utf8 = cpoolInfo_t('Utf8',1,
(lambda self,s:(s,)))
Class = cpoolInfo_t('Class',7,
(lambda self,n_id:self.getArgs(n_id)))
NameAndType = cpoolInfo_t('NameAndType',12,
(lambda self,n,d:self.getArgs(n) + self.getArgs(d)))
Field = cpoolInfo_t('Field',9,
(lambda self,c_id,nat_id:self.getArgs(c_id) + self.getArgs(nat_id)))
Method = cpoolInfo_t('Method',10,
(lambda self,c_id,nat_id:self.getArgs(c_id) + self.getArgs(nat_id)))
InterfaceMethod = cpoolInfo_t('InterfaceMethod',11,
(lambda self,c_id,nat_id:self.getArgs(c_id) + self.getArgs(nat_id)))
String = cpoolInfo_t('String',8,
(lambda self,n_id:self.getArgs(n_id)))
Int = cpoolInfo_t('Int',3,
(lambda self,s:(s,)))
Long = cpoolInfo_t('Long',5,
(lambda self,s:(s,)))
Float = cpoolInfo_t('Float',4,
(lambda self,s:decodeFloat(s)))
Double = cpoolInfo_t('Double',6,
(lambda self,s:decodeDouble(s)))
MethodHandle = cpoolInfo_t('MethodHandle',15,
(lambda self,t,n_id:(t,)+self.getArgs(n_id)))
MethodType = cpoolInfo_t('MethodType',16,
(lambda self,n_id:self.getArgs(n_id)))
InvokeDynamic = cpoolInfo_t('InvokeDynamic',18,
(lambda self,bs_id,nat_id:(bs_id,) + self.getArgs(nat_id)))
cpoolTypes = [Utf8, Class, NameAndType, Field, Method, InterfaceMethod,
String, Int, Long, Float, Double,
MethodHandle, MethodType, InvokeDynamic]
name2Type = {t.name:t for t in cpoolTypes}
tag2Type = {t.tag:t for t in cpoolTypes}
class ConstPool(object):
def __init__(self, initialData=((None,None),)):
self.pool = []
self.reserved = set()
self.available = set()
for tag, val in initialData:
if tag is None:
self.addEmptySlot()
else:
t = tag2Type[tag]
if t.name == 'Utf8':
val = decodeStr(*val)
self.pool.append((t.name, val))
def addEmptySlot(self):
self.pool.append((None, None))
def getArgs(self, i):
if not (i >= 0 and i<len(self.pool)):
raise IndexError('Constant pool index {} out of range'.format(i))
if self.pool[i][0] is None:
raise IndexError('Constant pool index {} invalid'.format(i))
name, val = self.pool[i]
t = name2Type[name]
return t.recoverArgs(self, *val)
def getArgsCheck(self, typen, index):
# if (self.pool[index][0] != typen):
# raise KeyError('Constant pool index {} has incorrect type {}'.format(index, typen))
val = self.getArgs(index)
return val if len(val) > 1 else val[0]
def getType(self, index): return self.pool[index][0]

View File

@ -0,0 +1,110 @@
import os.path
import zipfile
from .classfile import ClassFile
from .classfileformat.reader import Reader
from .error import ClassLoaderError
class Environment(object):
def __init__(self):
self.classes = {}
self.path = []
self._open = {}
def addToPath(self, path):
self.path.append(path)
def _getSuper(self, name):
return self.getClass(name).supername
def getClass(self, name, partial=False):
try:
result = self.classes[name]
except KeyError:
result = self._loadClass(name)
if not partial:
result.loadElements()
return result
def isSubclass(self, name1, name2):
if name2 == 'java/lang/Object':
return True
while name1 != 'java/lang/Object':
if name1 == name2:
return True
name1 = self._getSuper(name1)
return False
def commonSuperclass(self, name1, name2):
a, b = name1, name2
supers = {a}
while a != b and a != 'java/lang/Object':
a = self._getSuper(a)
supers.add(a)
while b not in supers:
b = self._getSuper(b)
return b
def isInterface(self, name, forceCheck=False):
try:
class_ = self.getClass(name, partial=True)
return 'INTERFACE' in class_.flags
except ClassLoaderError as e:
if forceCheck:
raise e
# If class is not found, assume worst case, that it is a interface
return True
def isFinal(self, name):
try:
class_ = self.getClass(name, partial=True)
return 'FINAL' in class_.flags
except ClassLoaderError as e:
return False
def _searchForFile(self, name):
name += '.class'
for place in self.path:
try:
archive = self._open[place]
except KeyError: # plain folder
try:
path = os.path.join(place, name)
with open(path, 'rb') as file_:
return file_.read()
except IOError:
print 'failed to open', path.encode('utf8')
else: # zip archive
try:
return archive.read(name)
except KeyError:
pass
def _loadClass(self, name):
print "Loading", name[:70]
data = self._searchForFile(name)
if data is None:
raise ClassLoaderError('ClassNotFoundException', name)
stream = Reader(data=data)
new = ClassFile(stream)
new.env = self
self.classes[new.name] = new
return new
# Context Manager methods to manager our zipfiles
def __enter__(self):
assert not self._open
for place in self.path:
if place.endswith('.jar') or place.endswith('.zip'):
self._open[place] = zipfile.ZipFile(place, 'r').__enter__()
return self
def __exit__(self, type_, value, traceback):
for place in reversed(self.path):
if place in self._open:
self._open[place].__exit__(type_, value, traceback)
del self._open[place]

View File

@ -0,0 +1,12 @@
class ClassLoaderError(Exception):
def __init__(self, typen=None, data=""):
self.type = typen
self.data = data
message = "\n{}: {}".format(typen, data) if typen else data
super(ClassLoaderError, self).__init__(message)
class VerificationError(Exception):
def __init__(self, message, data=None):
super(VerificationError, self).__init__(message)
self.data = data

View File

@ -0,0 +1,29 @@
from .attributes_raw import fixAttributeNames
class Field(object):
flagVals = {'PUBLIC':0x0001,
'PRIVATE':0x0002,
'PROTECTED':0x0004,
'STATIC':0x0008,
'FINAL':0x0010,
'VOLATILE':0x0040,
'TRANSIENT':0x0080,
'SYNTHETIC':0x1000,
'ENUM':0x4000,
}
def __init__(self, data, classFile, keepRaw):
self.class_ = classFile
cpool = self.class_.cpool
flags, name_id, desc_id, attributes_raw = data
self.name = cpool.getArgsCheck('Utf8', name_id)
self.descriptor = cpool.getArgsCheck('Utf8', desc_id)
self.attributes = fixAttributeNames(attributes_raw, cpool)
self.flags = set(name for name,mask in Field.flagVals.items() if (mask & flags))
self.static = 'STATIC' in self.flags
if keepRaw:
self.attributes_raw = attributes_raw
self.name_id, self.desc_id = name_id, desc_id

View File

@ -0,0 +1,81 @@
import math
INF_MAG = 1, None
ZERO_MAG = 0, None
# Numbers are represented as (sign, (mantissa, exponent))
# For finite nonzero values, the float value is sign * mantissa * 2 ^ (exponent - mbits - 1)
# Mantissa is normalized to always be within (2 ^ mbits) <= m < (2 ^ mbits + 1) even for subnormal numbers
NAN = None,(None,None)
INF = 1,INF_MAG
NINF = -1,INF_MAG
ZERO = 1,ZERO_MAG
NZERO = -1,ZERO_MAG
# Key suitable for sorting finite (normalized) nonzero values
sortkey = lambda (s,(m,e)):(s,s*e,s*m)
# Size info for type - mantissa bits, min exponent, max exponent
FLOAT_SIZE = 23,-126,127
DOUBLE_SIZE = 52,-1022,1023
def flog(x):
'''returns f such that 2**f <= x < 2**(f+1)'''
assert x > 0
return len(bin(x))-3
def roundMag(size, mag):
'''Round (unnormalized) magnitude to nearest representable magnitude with ties going to 0 lsb'''
mbits, emin, emax = size
m, e = mag
assert m >= 1
f = flog(m)
if e+f < emin: # subnormal
dnmin = emin - mbits
if e+f < (dnmin - 1):
return ZERO_MAG
if e > dnmin:
m = m << (e - dnmin)
f += (e - dnmin)
e = dnmin
s = dnmin - e
i = m >> s
r = (m - (i << s)) * 2
h = 1 << s
if r > h or r == h and (i&1):
i += 1
return i, e+s-mbits-1
else:
if f < mbits:
m = m << (mbits - f)
f = mbits
s = f - mbits
if (e+f) > emax:
return INF_MAG
i = m >> s
r = (m - (i << s)) * 2
h = 1 << s
if r > h or r == h and (i&1):
i += 1
if i == (1<<mbits):
i = i >> 1
e += 1
if e > emax:
return INF_MAG
return i, e+s-mbits-1
def fromRawFloat(size, x):
if math.isnan(x):
return NAN
sign = int(math.copysign(1, x))
x = math.copysign(x, 1)
if math.isinf(x):
return sign, INF_MAG
elif x == 0.0:
return sign, ZERO_MAG
else:
m, e = math.frexp(x)
m = int(m * (1<<(size[0]+1)))
return sign, roundMag(size, (m, e))

View File

@ -0,0 +1,68 @@
import itertools
def tarjanSCC(roots, getChildren):
"""Return a list of strongly connected components in a graph. If getParents is passed instead of getChildren, the result will be topologically sorted.
roots - list of root nodes to search from
getChildren - function which returns children of a given node
"""
sccs = []
indexCounter = itertools.count()
index = {}
lowlink = {}
removed = set()
subtree = []
# Use iterative version to avoid stack limits for large datasets
stack = [(node, 0) for node in roots]
while stack:
current, state = stack.pop()
if state == 0: # before recursing
if current not in index: # if it's in index, it was already visited (possibly earlier on the current search stack)
lowlink[current] = index[current] = next(indexCounter)
subtree.append(current)
stack.append((current, 1))
stack.extend((child, 0) for child in getChildren(current) if child not in removed)
else: # after recursing
children = [child for child in getChildren(current) if child not in removed]
for child in children:
if index[child] <= index[current]: # backedge (or selfedge)
lowlink[current] = min(lowlink[current], index[child])
else:
lowlink[current] = min(lowlink[current], lowlink[child])
assert lowlink[current] <= index[current]
if index[current] == lowlink[current]:
scc = []
while not scc or scc[-1] != current:
scc.append(subtree.pop())
sccs.append(tuple(scc))
removed.update(scc)
return sccs
def topologicalSort(roots, getParents):
"""Return a topological sorting of nodes in a graph.
roots - list of root nodes to search from
getParents - function which returns the parents of a given node
"""
results = []
visited = set()
# Use iterative version to avoid stack limits for large datasets
stack = [(node,0) for node in roots]
while stack:
current, state = stack.pop()
if state == 0: # before recursing
if current not in visited:
visited.add(current)
stack.append((current,1))
stack.extend((parent,0) for parent in getParents(current))
else: # after recursing
assert current in visited
results.append(current)
return results

View File

@ -0,0 +1,734 @@
import itertools
import math
from ..ssa import objtypes
from . import visitor
from .stringescape import escapeString
# Explicitly cast parameters to the desired type in order to avoid potential issues with overloaded methods
ALWAYS_CAST_PARAMS = 1
class VariableDeclarator(object):
def __init__(self, typename, identifier): self.typename = typename; self.local = identifier
def print_(self, printer, print_):
return '{} {}'.format(print_(self.typename), print_(self.local))
def tree(self, printer, tree): return [tree(self.typename), tree(self.local)]
#############################################################################################################################################
class JavaStatement(object):
expr = None # provide default for subclasses that don't have an expression
def getScopes(self): return ()
def fixLiterals(self):
if self.expr is not None:
self.expr = self.expr.fixLiterals()
def addCastsAndParens(self, env):
if self.expr is not None:
self.expr.addCasts(env)
self.expr.addParens()
class ExpressionStatement(JavaStatement):
def __init__(self, expr):
self.expr = expr
assert expr is not None
def print_(self, printer, print_): return print_(self.expr) + ';'
def tree(self, printer, tree): return [self.__class__.__name__, tree(self.expr)]
class LocalDeclarationStatement(JavaStatement):
def __init__(self, decl, expr=None):
self.decl = decl
self.expr = expr
def print_(self, printer, print_):
if self.expr is not None:
return '{} = {};'.format(print_(self.decl), print_(self.expr))
return print_(self.decl) + ';'
def tree(self, printer, tree): return [self.__class__.__name__, tree(self.expr), tree(self.decl)]
def addCastsAndParens(self, env):
if self.expr is not None:
self.expr.addCasts(env)
if not isJavaAssignable(env, self.expr.dtype, self.decl.typename.tt):
self.expr = makeCastExpr(self.decl.typename.tt, self.expr, fixEnv=env)
self.expr.addParens()
class ReturnStatement(JavaStatement):
def __init__(self, expr=None, tt=None):
self.expr = expr
self.tt = tt
def print_(self, printer, print_): return 'return {};'.format(print_(self.expr)) if self.expr is not None else 'return;'
def tree(self, printer, tree): return [self.__class__.__name__, tree(self.expr)]
def addCastsAndParens(self, env):
if self.expr is not None:
self.expr.addCasts(env)
if not isJavaAssignable(env, self.expr.dtype, self.tt):
self.expr = makeCastExpr(self.tt, self.expr, fixEnv=env)
self.expr.addParens()
class ThrowStatement(JavaStatement):
def __init__(self, expr):
self.expr = expr
def print_(self, printer, print_): return 'throw {};'.format(print_(self.expr))
def tree(self, printer, tree): return [self.__class__.__name__, tree(self.expr)]
class JumpStatement(JavaStatement):
def __init__(self, target, isFront):
self.label = target.getLabel() if target is not None else None
self.keyword = 'continue' if isFront else 'break'
def print_(self, printer, print_):
label = ' ' + self.label if self.label is not None else ''
return self.keyword + label + ';'
def tree(self, printer, tree): return [self.__class__.__name__, self.keyword, self.label]
# Compound Statements
sbcount = itertools.count()
class LazyLabelBase(JavaStatement):
# Jumps are represented by arbitrary 'keys', currently just the key of the
# original proxy node. Each item has a continueKey and a breakKey representing
# the beginning and the point just past the end respectively. breakKey may be
# None if this item appears at the end of the function and there is nothing after it.
# Statement blocks have a jump key representing where it jumps to if any. This
# may be None if the jump is unreachable (such as if there is a throw or return)
def __init__(self, labelfunc, begink, endk):
self.label, self.func = None, labelfunc
self.continueKey = begink
self.breakKey = endk
self.id = next(sbcount) # For debugging purposes
def getLabel(self):
if self.label is None:
self.label = self.func() # Not a bound function!
return self.label
def getLabelPrefix(self): return '' if self.label is None else self.label + ': '
# def getLabelPrefix(self): return self.getLabel() + ': '
# For debugging
def __str__(self): # pragma: no cover
if isinstance(self, StatementBlock):
return 'Sb'+str(self.id)
return type(self).__name__[:3]+str(self.id)
__repr__ = __str__
class TryStatement(LazyLabelBase):
def __init__(self, labelfunc, begink, endk, tryb, pairs):
super(TryStatement, self).__init__(labelfunc, begink, endk)
self.tryb, self.pairs = tryb, pairs
def getScopes(self): return (self.tryb,) + zip(*self.pairs)[1]
def print_(self, printer, print_):
tryb = print_(self.tryb)
parts = ['catch({}) {}'.format(print_(x), print_(y)) for x,y in self.pairs]
return '{}try {} {}'.format(self.getLabelPrefix(), tryb, '\n'.join(parts))
def tree(self, printer, tree):
parts = [map(tree, t) for t in self.pairs]
return [self.__class__.__name__, self.label, tree(self.tryb), parts]
class IfStatement(LazyLabelBase):
def __init__(self, labelfunc, begink, endk, expr, scopes):
super(IfStatement, self).__init__(labelfunc, begink, endk)
self.expr = expr # don't rename without changing how var replacement works!
self.scopes = scopes
def getScopes(self): return self.scopes
def print_(self, printer, print_):
lbl = self.getLabelPrefix()
parts = [self.expr] + list(self.scopes)
if len(self.scopes) == 1:
parts = [print_(x) for x in parts]
return '{}if ({}) {}'.format(lbl, *parts)
# Special case handling for 'else if'
fblock = self.scopes[1]
if len(fblock.statements) == 1:
stmt = fblock.statements[-1]
if isinstance(stmt, IfStatement) and stmt.label is None:
parts[-1] = stmt
parts = [print_(x) for x in parts]
return '{}if ({}) {} else {}'.format(lbl, *parts)
def tree(self, printer, tree): return [self.__class__.__name__, self.label, tree(self.expr), map(tree, self.scopes)]
class SwitchStatement(LazyLabelBase):
def __init__(self, labelfunc, begink, endk, expr, pairs):
super(SwitchStatement, self).__init__(labelfunc, begink, endk)
self.expr = expr # don't rename without changing how var replacement works!
self.pairs = pairs
def getScopes(self): return zip(*self.pairs)[1]
def hasDefault(self): return None in zip(*self.pairs)[0]
def print_(self, printer, print_):
expr = print_(self.expr)
def printCase(keys):
if keys is None:
return 'default: '
assert keys
return ''.join(map('case {}: '.format, sorted(keys)))
bodies = [(printCase(keys) + print_(scope)) for keys, scope in self.pairs]
if self.pairs[-1][0] is None and len(self.pairs[-1][1].statements) == 0:
bodies.pop()
contents = '\n'.join(bodies)
indented = [' '+line for line in contents.splitlines()]
lines = ['{'] + indented + ['}']
return '{}switch({}) {}'.format(self.getLabelPrefix(), expr, '\n'.join(lines))
def tree(self, printer, tree):
parts = []
for keys, scope in self.pairs:
parts.append([[None] if keys is None else sorted(keys), tree(scope)])
return [self.__class__.__name__, self.label, tree(self.expr), parts]
class WhileStatement(LazyLabelBase):
def __init__(self, labelfunc, begink, endk, parts):
super(WhileStatement, self).__init__(labelfunc, begink, endk)
self.expr = Literal.TRUE
self.parts = parts
assert len(self.parts) == 1
def getScopes(self): return self.parts
def print_(self, printer, print_):
parts = print_(self.expr), print_(self.parts[0])
return '{}while({}) {}'.format(self.getLabelPrefix(), *parts)
def tree(self, printer, tree): return [self.__class__.__name__, self.label, tree(self.expr), tree(self.parts[0])]
class StatementBlock(LazyLabelBase):
def __init__(self, labelfunc, begink, endk, statements, jumpk, labelable=True):
super(StatementBlock, self).__init__(labelfunc, begink, endk)
self.parent = None # should be assigned later
self.statements = statements
self.jumpKey = jumpk
self.labelable = labelable
def doesFallthrough(self): return self.jumpKey is None or self.jumpKey == self.breakKey
def getScopes(self): return self,
def print_(self, printer, print_):
assert self.labelable or self.label is None
contents = '\n'.join(print_(x) for x in self.statements)
indented = [' '+line for line in contents.splitlines()]
# indented[:0] = [' //{} {}'.format(self,x) for x in (self.continueKey, self.breakKey, self.jumpKey)]
lines = [self.getLabelPrefix() + '{'] + indented + ['}']
return '\n'.join(lines)
@staticmethod
def join(*scopes):
blists = [s.bases for s in scopes if s is not None] # allow None to represent the universe (top element)
if not blists:
return None
common = [x for x in zip(*blists) if len(set(x)) == 1]
return common[-1][0]
def tree(self, printer, tree): return ['BlockStatement', self.label, map(tree, self.statements)]
#############################################################################################################################################
# Careful, order is important here!
_assignable_sprims = objtypes.ByteTT, objtypes.ShortTT, objtypes.CharTT
_assignable_lprims = objtypes.IntTT, objtypes.LongTT, objtypes.FloatTT, objtypes.DoubleTT
# Also used in boolize.py
def isPrimativeAssignable(x, y): # x = fromt, y = to
assert objtypes.dim(x) == objtypes.dim(y) == 0
if x == y or (x in _assignable_sprims and y in _assignable_lprims):
return True
elif (x in _assignable_lprims and y in _assignable_lprims):
return _assignable_lprims.index(x) <= _assignable_lprims.index(y)
else:
return (x, y) == (objtypes.ByteTT, objtypes.ShortTT)
def isReferenceType(tt):
return tt == objtypes.NullTT or objtypes.dim(tt) or (objtypes.className(tt) is not None)
def isJavaAssignable(env, fromt, to):
if fromt is None or to is None: # this should never happen, except during debugging
return True
if isReferenceType(to):
assert isReferenceType(fromt)
# todo - make it check interfaces too
return objtypes.isSubtype(env, fromt, to)
else: # allowed if numeric conversion is widening
return isPrimativeAssignable(fromt, to)
_int_tts = objtypes.LongTT, objtypes.IntTT, objtypes.ShortTT, objtypes.CharTT, objtypes.ByteTT
def makeCastExpr(newtt, expr, fixEnv=None):
if newtt == expr.dtype:
return expr
# if casting a literal with compatible type, just create a literal of the new type
if isinstance(expr, Literal):
allowed_conversions = [
(objtypes.FloatTT, objtypes.DoubleTT),
(objtypes.IntTT, objtypes.LongTT),
(objtypes.IntTT, objtypes.BoolTT),
(objtypes.BoolTT, objtypes.IntTT),
]
if (expr.dtype, newtt) in allowed_conversions:
return Literal(newtt, expr.val)
if newtt == objtypes.IntTT and expr.dtype == objtypes.BoolTT:
return Ternary(expr, Literal.ONE, Literal.ZERO)
elif newtt == objtypes.BoolTT and expr.dtype == objtypes.IntTT:
return BinaryInfix('!=', [expr, Literal.ZERO], objtypes.BoolTT)
ret = Cast(TypeName(newtt), expr)
if fixEnv is not None:
ret = ret.fix(fixEnv)
return ret
#############################################################################################################################################
# Precedence:
# 0 - pseudoprimary
# 5 - pseudounary
# 10-19 binary infix
# 20 - ternary
# 21 - assignment
# Associativity: L = Left, R = Right, A = Full
class JavaExpression(object):
precedence = 0 # Default precedence
params = [] # for subclasses that don't have params
def complexity(self): return 1 + max(e.complexity() for e in self.params) if self.params else 0
def postFlatIter(self):
return itertools.chain([self], *[expr.postFlatIter() for expr in self.params])
def print_(self, printer, print_):
return self.fmt.format(*[print_(expr) for expr in self.params])
def tree(self, printer, tree): return [self.__class__.__name__, map(tree, self.params)]
def replaceSubExprs(self, rdict):
if self in rdict:
return rdict[self]
self.params = [param.replaceSubExprs(rdict) for param in self.params]
return self
def fixLiterals(self):
self.params = [param.fixLiterals() for param in self.params]
return self
def addCasts(self, env):
for param in self.params:
param.addCasts(env)
self.addCasts_sub(env)
def addCasts_sub(self, env): pass
def addParens(self):
for param in self.params:
param.addParens()
self.params = list(self.params) # Copy before editing, just to be extra safe
self.addParens_sub()
def addParens_sub(self): pass
def isLocalAssign(self): return isinstance(self, Assignment) and isinstance(self.params[0], Local)
def __repr__(self): # pragma: no cover
return type(self).__name__.rpartition('.')[-1] + ' ' + visitor.DefaultVisitor().visit(self)
__str__ = __repr__
class ArrayAccess(JavaExpression):
def __init__(self, *params):
if params[0].dtype == objtypes.NullTT:
# Unfortunately, Java doesn't really support array access on null constants
#So we'll just cast it to Object[] as a hack
param = makeCastExpr(objtypes.withDimInc(objtypes.ObjectTT, 1), params[0])
params = param, params[1]
self.params = list(params)
self.fmt = '{}[{}]'
@property
def dtype(self): return objtypes.withDimInc(self.params[0].dtype, -1)
def addParens_sub(self):
p0 = self.params[0]
if p0.precedence > 0 or isinstance(p0, ArrayCreation):
self.params[0] = Parenthesis(p0)
class ArrayCreation(JavaExpression):
def __init__(self, tt, *sizeargs):
self.dim = objtypes.dim(tt)
self.params = [TypeName(objtypes.withNoDim(tt))] + list(sizeargs)
self.dtype = tt
assert self.dim >= len(sizeargs) > 0
self.fmt = 'new {}' + '[{}]'*len(sizeargs) + '[]'*(self.dim-len(sizeargs))
def tree(self, printer, tree): return [self.__class__.__name__, map(tree, self.params), self.dim]
class Assignment(JavaExpression):
precedence = 21
def __init__(self, *params):
self.params = list(params)
self.fmt = '{} = {}'
@property
def dtype(self): return self.params[0].dtype
def addCasts_sub(self, env):
left, right = self.params
if not isJavaAssignable(env, right.dtype, left.dtype):
expr = makeCastExpr(left.dtype, right, fixEnv=env)
self.params = [left, expr]
def tree(self, printer, tree): return [self.__class__.__name__, map(tree, self.params), '']
_binary_ptable = ['* / %', '+ -', '<< >> >>>',
'< > <= >= instanceof', '== !=',
'&', '^', '|', '&&', '||']
binary_precedences = {}
for _ops, _val in zip(_binary_ptable, range(10,20)):
for _op in _ops.split():
binary_precedences[_op] = _val
class BinaryInfix(JavaExpression):
def __init__(self, opstr, params, dtype=None):
assert len(params) == 2
self.params = params
self.opstr = opstr
self.fmt = '{{}} {} {{}}'.format(opstr)
self._dtype = dtype
self.precedence = binary_precedences[opstr]
@property
def dtype(self): return self.params[0].dtype if self._dtype is None else self._dtype
def addParens_sub(self):
myprec = self.precedence
associative = myprec >= 15 # for now we treat +, *, etc as nonassociative due to floats
for i, p in enumerate(self.params):
if p.precedence > myprec:
self.params[i] = Parenthesis(p)
elif p.precedence == myprec and i > 0 and not associative:
self.params[i] = Parenthesis(p)
def tree(self, printer, tree): return [self.__class__.__name__, map(tree, self.params), self.opstr]
class Cast(JavaExpression):
precedence = 5
def __init__(self, *params):
self.dtype = params[0].tt
self.params = list(params)
self.fmt = '({}){}'
def fix(self, env):
tt, expr = self.dtype, self.params[1]
# "Impossible" casts are a compile error in Java.
# This can be fixed with an intermediate cast to Object
if isReferenceType(tt):
if not isJavaAssignable(env, tt, expr.dtype):
if not isJavaAssignable(env, expr.dtype, tt):
expr = makeCastExpr(objtypes.ObjectTT, expr)
self.params = [self.params[0], expr]
return self
def addCasts_sub(self, env): self.fix(env)
def addParens_sub(self):
p1 = self.params[1]
if p1.precedence > 5 or (isinstance(p1, UnaryPrefix) and p1.opstr[0] in '-+'):
self.params[1] = Parenthesis(p1)
class ClassInstanceCreation(JavaExpression):
def __init__(self, typename, tts, arguments):
self.typename, self.tts, self.params = typename, tts, arguments
self.dtype = typename.tt
def print_(self, printer, print_):
return 'new {}({})'.format(print_(self.typename), ', '.join(print_(x) for x in self.params))
def tree(self, printer, tree):
return [self.__class__.__name__, map(tree, self.params), tree(self.typename)]
def addCasts_sub(self, env):
newparams = []
for tt, expr in zip(self.tts, self.params):
if expr.dtype != tt and (ALWAYS_CAST_PARAMS or not isJavaAssignable(env, expr.dtype, tt)):
expr = makeCastExpr(tt, expr, fixEnv=env)
newparams.append(expr)
self.params = newparams
class FieldAccess(JavaExpression):
def __init__(self, primary, name, dtype, op=None, printLeft=True):
self.dtype = dtype
self.params = [primary]
self.op, self.name = op, name
self.printLeft = printLeft
# self.params, self.name = [primary], escapeString(name)
# self.fmt = ('{}.' if printLeft else '') + self.name
def print_(self, printer, print_):
if self.op is None:
name = self.name
assert name in ('length','class')
else:
cls, name, desc = self.op.target, self.op.name, self.op.desc
name = escapeString(printer.fieldName(cls, name, desc))
pre = print_(self.params[0])+'.' if self.printLeft else ''
return pre+name
def tree(self, printer, tree):
if self.op is None:
trip = None, self.name, None
else:
trip = self.op.target, self.op.name, self.op.desc
return [self.__class__.__name__, map(tree, self.params), trip, self.printLeft]
def addParens_sub(self):
p0 = self.params[0]
if p0.precedence > 0:
self.params[0] = Parenthesis(p0)
def printFloat(x, isSingle):
assert x >= 0.0 and not math.isinf(x)
suffix = 'f' if isSingle else ''
if isSingle and x > 0.0:
# Try to find more compract representation for floats, since repr treats everything as doubles
m, e = math.frexp(x)
half_ulp2 = math.ldexp(1.0, max(e - 25, -150)) # don't bother doubling when near the upper range of a given e value
half_ulp1 = (half_ulp2/2) if m == 0.5 and e >= -125 else half_ulp2
lbound, ubound = x-half_ulp1, x+half_ulp2
assert lbound < x < ubound
s = '{:g}'.format(x).replace('+','')
if lbound < float(s) < ubound: # strict ineq to avoid potential double rounding issues
return s + suffix
return repr(x) + suffix
class Literal(JavaExpression):
def __init__(self, vartype, val):
self.dtype = vartype
self.val = val
if self.dtype == objtypes.ClassTT:
self.params = [TypeName(val)]
def getStr(self):
if self.dtype == objtypes.StringTT:
return '"' + escapeString(self.val) + '"'
elif self.dtype == objtypes.IntTT:
return str(self.val)
elif self.dtype == objtypes.LongTT:
return str(self.val) + 'L'
elif self.dtype == objtypes.FloatTT or self.dtype == objtypes.DoubleTT:
return printFloat(self.val, self.dtype == objtypes.FloatTT)
elif self.dtype == objtypes.NullTT:
return 'null'
elif self.dtype == objtypes.BoolTT:
return 'true' if self.val else 'false'
def fixLiterals(self):
# From the point of view of the Java Language, there is no such thing as a negative literal.
# This replaces invalid literal values with unary minus (and division for non-finite floats)
if self.dtype == objtypes.IntTT or self.dtype == objtypes.LongTT:
if self.val < 0:
return UnaryPrefix('-', Literal(self.dtype, -self.val))
elif self.dtype == objtypes.FloatTT or self.dtype == objtypes.DoubleTT:
x = self.val
zero = Literal.DZERO if self.dtype == objtypes.DoubleTT else Literal.FZERO
if math.isnan(x):
return BinaryInfix('/', [zero, zero])
elif math.isinf(x): #+/- inf
numerator = Literal(self.dtype, math.copysign(1.0, x)).fixLiterals()
return BinaryInfix('/', [numerator, zero])
# finite negative numbers
if math.copysign(1.0, x) == -1.0:
return UnaryPrefix('-', Literal(self.dtype, math.copysign(x, 1.0)))
return self
def print_(self, printer, print_):
if self.dtype == objtypes.ClassTT:
# for printing class literals
return '{}.class'.format(print_(self.params[0]))
return self.getStr()
def tree(self, printer, tree):
result = tree(self.params[0]) if self.dtype == objtypes.ClassTT else self.getStr()
return [self.__class__.__name__, result, self.dtype]
def _key(self): return self.dtype, self.val
def __eq__(self, other): return type(self) == type(other) and self._key() == other._key()
def __ne__(self, other): return type(self) != type(other) or self._key() != other._key()
def __hash__(self): return hash(self._key())
Literal.FALSE = Literal(objtypes.BoolTT, 0)
Literal.TRUE = Literal(objtypes.BoolTT, 1)
Literal.N_ONE = Literal(objtypes.IntTT, -1)
Literal.ZERO = Literal(objtypes.IntTT, 0)
Literal.ONE = Literal(objtypes.IntTT, 1)
Literal.LZERO = Literal(objtypes.LongTT, 0)
Literal.FZERO = Literal(objtypes.FloatTT, 0.0)
Literal.DZERO = Literal(objtypes.DoubleTT, 0.0)
Literal.NULL = Literal(objtypes.NullTT, None)
_init_d = {objtypes.BoolTT: Literal.FALSE,
objtypes.IntTT: Literal.ZERO,
objtypes.LongTT: Literal.LZERO,
objtypes.FloatTT: Literal.FZERO,
objtypes.DoubleTT: Literal.DZERO}
def dummyLiteral(tt):
return _init_d.get(tt, Literal.NULL)
class Local(JavaExpression):
def __init__(self, vartype, namefunc):
self.dtype = vartype
self.name = None
self.func = namefunc
def print_(self, printer, print_):
if self.name is None:
self.name = self.func(self)
return self.name
def tree(self, printer, tree): return [self.__class__.__name__, self.print_(None, None)]
class MethodInvocation(JavaExpression):
def __init__(self, left, name, tts, arguments, op, dtype):
if left is None:
self.params = arguments
else:
self.params = [left] + arguments
self.hasLeft = (left is not None)
self.dtype = dtype
self.name = name
self.tts = tts
self.op = op # keep around for future reference and new merging
def print_(self, printer, print_):
cls, name, desc = self.op.target, self.op.name, self.op.desc
if name != self.name:
assert name == '<init>'
name = self.name
else:
name = escapeString(printer.methodName(cls, name, desc))
if self.hasLeft:
left, arguments = self.params[0], self.params[1:]
return '{}.{}({})'.format(print_(left), name, ', '.join(print_(x) for x in arguments))
else:
arguments = self.params
return '{}({})'.format(name, ', '.join(print_(x) for x in arguments))
def tree(self, printer, tree):
trip = self.op.target, self.op.name, self.op.desc
return [self.__class__.__name__, map(tree, self.params), trip, self.name, self.hasLeft]
def addCasts_sub(self, env):
newparams = []
for tt, expr in zip(self.tts, self.params):
if expr.dtype != tt and (ALWAYS_CAST_PARAMS or not isJavaAssignable(env, expr.dtype, tt)):
expr = makeCastExpr(tt, expr, fixEnv=env)
newparams.append(expr)
self.params = newparams
def addParens_sub(self):
if self.hasLeft:
p0 = self.params[0]
if p0.precedence > 0:
self.params[0] = Parenthesis(p0)
class Parenthesis(JavaExpression):
def __init__(self, param):
self.params = [param]
self.fmt = '({})'
@property
def dtype(self): return self.params[0].dtype
class Ternary(JavaExpression):
precedence = 20
def __init__(self, *params):
self.params = list(params)
self.fmt = '{} ? {} : {}'
@property
def dtype(self): return self.params[1].dtype
def addParens_sub(self):
# Add unecessary parenthesis to complex conditions for readability
if self.params[0].precedence >= 20 or self.params[0].complexity() > 0:
self.params[0] = Parenthesis(self.params[0])
if self.params[2].precedence > 20:
self.params[2] = Parenthesis(self.params[2])
class TypeName(JavaExpression):
def __init__(self, tt):
self.dtype = None
self.tt = tt
def print_(self, printer, print_):
name = objtypes.className(self.tt)
if name is not None:
name = printer.className(name)
name = escapeString(name.replace('/','.'))
if name.rpartition('.')[0] == 'java.lang':
name = name.rpartition('.')[2]
else:
name = objtypes.primName(self.tt)
s = name + '[]'*objtypes.dim(self.tt)
return s
def tree(self, printer, tree): return [self.__class__.__name__, self.tt]
def complexity(self): return -1 # exprs which have this as a param won't be bumped up to 1 uncessarily
class CatchTypeNames(JavaExpression): # Used for caught exceptions, which can have multiple types specified
def __init__(self, env, tts):
assert(tts and not any(objtypes.dim(tt) for tt in tts)) # at least one type, no array types
self.tnames = map(TypeName, tts)
self.dtype = objtypes.commonSupertype(env, tts)
def print_(self, printer, print_):
return ' | '.join(print_(tn) for tn in self.tnames)
def tree(self, printer, tree): return [self.__class__.__name__, map(tree, self.tnames)]
class UnaryPrefix(JavaExpression):
precedence = 5
def __init__(self, opstr, param, dtype=None):
self.params = [param]
self.opstr = opstr
self.fmt = opstr + '{}'
self._dtype = dtype
@property
def dtype(self): return self.params[0].dtype if self._dtype is None else self._dtype
def addParens_sub(self):
p0 = self.params[0]
if p0.precedence > 5 or (isinstance(p0, UnaryPrefix) and p0.opstr[0] == self.opstr[0]):
self.params[0] = Parenthesis(p0)
def tree(self, printer, tree): return ['Unary', map(tree, self.params), self.opstr, False]
class Dummy(JavaExpression):
def __init__(self, fmt, params, isNew=False, dtype=None):
self.params = params
self.fmt = fmt
self.isNew = isNew
self.dtype = dtype

View File

@ -0,0 +1,138 @@
from ..ssa import objtypes
from . import ast
from .stringescape import escapeString as escape
class Comments(object):
def __init__(self):
self.lines = []
def add(self, s):
self.lines.extend(s.strip('\n').split('\n'))
def print_(self, printer, print_):
return ''.join(map('// {}\n'.format, self.lines))
class MethodDef(object):
def __init__(self, class_, flags, name, desc, retType, paramDecls, body):
self.flagstr = flags + ' ' if flags else ''
self.retType, self.paramDecls = retType, paramDecls
self.body = body
self.comments = Comments()
self.triple = class_.name, name, desc
self.throws = None
if name == '<clinit>':
self.isStaticInit, self.isConstructor = True, False
elif name == '<init>':
self.isStaticInit, self.isConstructor = False, True
self.clsname = ast.TypeName(objtypes.TypeTT(class_.name, 0))
else:
self.isStaticInit, self.isConstructor = False, False
def print_(self, printer, print_):
header = print_(self.comments)
argstr = ', '.join(print_(decl) for decl in self.paramDecls)
if self.isStaticInit:
header += 'static'
elif self.isConstructor:
name = print_(self.clsname).rpartition('.')[-1]
header += '{}{}({})'.format(self.flagstr, name, argstr)
else:
name = printer.methodName(*self.triple)
header += '{}{} {}({})'.format(self.flagstr, print_(self.retType), escape(name), argstr)
if self.throws is not None:
header += ' throws ' + print_(self.throws)
if self.body is None:
if 'abstract' not in self.flagstr and 'native' not in self.flagstr:
# Create dummy body for decompiler error
return header + ' {/*error*/throw null;}\n'
return header + ';\n'
else:
return header + ' ' + print_(self.body)
def tree(self, printer, tree):
return {
'triple': self.triple,
'flags': self.flagstr.split(),
'ret': tree(self.retType),
'params': map(tree, self.paramDecls),
'comments': self.comments.lines,
'body': tree(self.body),
'throws': tree(self.throws),
}
class FieldDef(object):
def __init__(self, flags, type_, class_, name, desc, expr=None):
self.flagstr = flags + ' ' if flags else ''
self.type_ = type_
self.name = name
self.expr = None if expr is None else ast.makeCastExpr(type_.tt, expr)
self.triple = class_.name, name, desc
def print_(self, printer, print_):
name = escape(printer.fieldName(*self.triple))
if self.expr is not None:
return '{}{} {} = {};'.format(self.flagstr, print_(self.type_), name, print_(self.expr))
return '{}{} {};'.format(self.flagstr, print_(self.type_), name)
def tree(self, printer, tree):
return {
'triple': self.triple,
'type': tree(self.type_),
'flags': self.flagstr.split(),
'expr': tree(self.expr),
}
class ClassDef(object):
def __init__(self, flags, isInterface, name, superc, interfaces, fields, methods):
self.flagstr = flags + ' ' if flags else ''
self.isInterface = isInterface
self.name = ast.TypeName(objtypes.TypeTT(name,0))
self.super = ast.TypeName(objtypes.TypeTT(superc,0)) if superc is not None else None
self.interfaces = [ast.TypeName(objtypes.TypeTT(iname,0)) for iname in interfaces]
self.fields = fields
self.methods = methods
if superc == 'java/lang/Object':
self.super = None
def print_(self, printer, print_):
contents = ''
if self.fields:
contents = '\n'.join(print_(x) for x in self.fields)
if self.methods:
if contents:
contents += '\n\n' # extra line to divide fields and methods
contents += '\n\n'.join(print_(x) for x in self.methods)
indented = [' '+line for line in contents.splitlines()]
name = print_(self.name).rpartition('.')[-1]
defname = 'interface' if self.isInterface else 'class'
header = '{}{} {}'.format(self.flagstr, defname, name)
if self.super:
header += ' extends ' + print_(self.super)
if self.interfaces:
if self.isInterface:
assert self.super is None
header += ' extends ' + ', '.join(print_(x) for x in self.interfaces)
else:
header += ' implements ' + ', '.join(print_(x) for x in self.interfaces)
lines = [header + ' {'] + indented + ['}']
return '\n'.join(lines) + '\n'
# Experimental - don't use!
def tree(self, printer, tree):
return {
'rawname': objtypes.className(self.name.tt),
'name': tree(self.name),
'super': tree(self.super),
'flags': self.flagstr.split(),
'isInterface': self.isInterface,
'interfaces': map(tree, self.interfaces),
'fields': map(tree, self.fields),
'methods': map(tree, self.methods),
}

View File

@ -0,0 +1,322 @@
from .. import opnames
from ..namegen import LabelGen
from ..ssa import objtypes, ssa_jumps, ssa_ops, ssa_types
from ..verifier.descriptors import parseFieldDescriptor, parseMethodDescriptor
from . import ast
from .setree import SEBlockItem, SEIf, SEScope, SESwitch, SETry, SEWhile
# prefixes for name generation
_prefix_map = {objtypes.IntTT:'i', objtypes.LongTT:'j',
objtypes.FloatTT:'f', objtypes.DoubleTT:'d',
objtypes.BoolTT:'b', objtypes.StringTT:'s'}
_ssaToTT = {ssa_types.SSA_INT:objtypes.IntTT, ssa_types.SSA_LONG:objtypes.LongTT,
ssa_types.SSA_FLOAT:objtypes.FloatTT, ssa_types.SSA_DOUBLE:objtypes.DoubleTT}
class VarInfo(object):
def __init__(self, method, blocks, namegen):
self.env = method.class_.env
self.labelgen = LabelGen().next
returnTypes = parseMethodDescriptor(method.descriptor, unsynthesize=False)[-1]
self.return_tt = objtypes.verifierToSynthetic(returnTypes[0]) if returnTypes else None
self.clsname = method.class_.name
self._namegen = namegen
self._uninit_vars = {}
self._vars = {}
self._tts = {}
for block in blocks:
for var, uc in block.unaryConstraints.items():
if var.type == ssa_types.SSA_OBJECT:
tt = uc.getSingleTType() # temp hack
if uc.types.isBoolOrByteArray():
tt = objtypes.TypeTT(objtypes.BExpr, objtypes.dim(tt)+1)
# assert (objtypes.BoolTT[0], tt[1]) in uc.types.exact
else:
tt = _ssaToTT[var.type]
self._tts[var] = tt
def _nameCallback(self, expr):
prefix = _prefix_map.get(expr.dtype, 'a')
return self._namegen.getPrefix(prefix)
def _newVar(self, var, num, isCast):
tt = self._tts[var]
if var.const is not None and not isCast:
return ast.Literal(tt, var.const)
if var.name:
# important to not add num when it is 0, since we currently
# use var names to force 'this'
temp = '{}_{}'.format(var.name, num) if num else var.name
if isCast:
temp += 'c'
namefunc = lambda expr:temp
else:
namefunc = self._nameCallback
result = ast.Local(tt, namefunc)
# merge all variables of uninitialized type to simplify fixObjectCreations in javamethod.py
if var.uninit_orig_num is not None and not isCast:
result = self._uninit_vars.setdefault(var.uninit_orig_num, result)
return result
def var(self, node, var, isCast=False):
key = node, var, isCast
try:
return self._vars[key]
except KeyError:
new = self._newVar(key[1], key[0].num, key[2])
self._vars[key] = new
return new
def customVar(self, tt, prefix): # for use with ignored exceptions
namefunc = lambda expr: self._namegen.getPrefix(prefix)
return ast.Local(tt, namefunc)
#########################################################################################
_math_types = (ssa_ops.IAdd, ssa_ops.IDiv, ssa_ops.IMul, ssa_ops.IRem, ssa_ops.ISub)
_math_types += (ssa_ops.IAnd, ssa_ops.IOr, ssa_ops.IShl, ssa_ops.IShr, ssa_ops.IUshr, ssa_ops.IXor)
_math_types += (ssa_ops.FAdd, ssa_ops.FDiv, ssa_ops.FMul, ssa_ops.FRem, ssa_ops.FSub)
_math_symbols = dict(zip(_math_types, '+ / * % - & | << >> >>> ^ + / * % -'.split()))
def _convertJExpr(op, getExpr, clsname):
params = [getExpr(var) for var in op.params]
assert None not in params
expr = None
# Have to do this one seperately since it isn't an expression statement
if isinstance(op, ssa_ops.Throw):
return ast.ThrowStatement(params[0])
if isinstance(op, _math_types):
opdict = _math_symbols
expr = ast.BinaryInfix(opdict[type(op)], params)
elif isinstance(op, ssa_ops.ArrLength):
expr = ast.FieldAccess(params[0], 'length', objtypes.IntTT)
elif isinstance(op, ssa_ops.ArrLoad):
expr = ast.ArrayAccess(*params)
elif isinstance(op, ssa_ops.ArrStore):
expr = ast.ArrayAccess(params[0], params[1])
expr = ast.Assignment(expr, params[2])
elif isinstance(op, ssa_ops.CheckCast):
expr = ast.Cast(ast.TypeName(op.target_tt), params[0])
elif isinstance(op, ssa_ops.Convert):
expr = ast.makeCastExpr(_ssaToTT[op.target], params[0])
elif isinstance(op, (ssa_ops.FCmp, ssa_ops.ICmp)):
boolt = objtypes.BoolTT
cn1, c0, c1 = ast.Literal.N_ONE, ast.Literal.ZERO, ast.Literal.ONE
ascend = isinstance(op, ssa_ops.ICmp) or op.NaN_val == 1
if ascend:
expr = ast.Ternary(ast.BinaryInfix('<',params,boolt), cn1, ast.Ternary(ast.BinaryInfix('==',params,boolt), c0, c1))
else:
assert op.NaN_val == -1
expr = ast.Ternary(ast.BinaryInfix('>',params,boolt), c1, ast.Ternary(ast.BinaryInfix('==',params,boolt), c0, cn1))
elif isinstance(op, ssa_ops.FieldAccess):
dtype = objtypes.verifierToSynthetic(parseFieldDescriptor(op.desc, unsynthesize=False)[0])
if op.instruction[0] in (opnames.GETSTATIC, opnames.PUTSTATIC):
printLeft = (op.target != clsname) # Don't print classname if it is a static field in current class
tt = objtypes.TypeTT(op.target, 0) # Doesn't handle arrays, but they don't have any fields anyway
expr = ast.FieldAccess(ast.TypeName(tt), op.name, dtype, op, printLeft=printLeft)
else:
expr = ast.FieldAccess(params[0], op.name, dtype, op)
if op.instruction[0] in (opnames.PUTFIELD, opnames.PUTSTATIC):
expr = ast.Assignment(expr, params[-1])
elif isinstance(op, ssa_ops.FNeg):
expr = ast.UnaryPrefix('-', params[0])
elif isinstance(op, ssa_ops.InstanceOf):
args = [params[0], ast.TypeName(op.target_tt)]
expr = ast.BinaryInfix('instanceof', args, dtype=objtypes.BoolTT)
elif isinstance(op, ssa_ops.Invoke):
vtypes, rettypes = parseMethodDescriptor(op.desc, unsynthesize=False)
tt_types = objtypes.verifierToSynthetic_seq(vtypes)
ret_type = objtypes.verifierToSynthetic(rettypes[0]) if rettypes else None
target_tt = op.target_tt
if objtypes.dim(target_tt) and op.name == "clone": # In Java, T[].clone returns T[] rather than Object
ret_type = target_tt
if op.instruction[0] == opnames.INVOKEINIT and op.isThisCtor:
name = 'this' if (op.target == clsname) else 'super'
expr = ast.MethodInvocation(None, name, tt_types, params[1:], op, ret_type)
elif op.instruction[0] == opnames.INVOKESTATIC: # TODO - fix this for special super calls
expr = ast.MethodInvocation(ast.TypeName(target_tt), op.name, [None]+tt_types, params, op, ret_type)
else:
expr = ast.MethodInvocation(params[0], op.name, [target_tt]+tt_types, params[1:], op, ret_type)
elif isinstance(op, ssa_ops.InvokeDynamic):
vtypes, rettypes = parseMethodDescriptor(op.desc, unsynthesize=False)
ret_type = objtypes.verifierToSynthetic(rettypes[0]) if rettypes else None
fmt = '/*invokedynamic*/'
if ret_type is not None:
fmt += '{{{}}}'.format(len(params))
params.append(ast.dummyLiteral(ret_type))
expr = ast.Dummy(fmt, params, dtype=ret_type)
elif isinstance(op, ssa_ops.Monitor):
fmt = '/*monexit({})*/' if op.exit else '/*monenter({})*/'
expr = ast.Dummy(fmt, params)
elif isinstance(op, ssa_ops.MultiNewArray):
expr = ast.ArrayCreation(op.tt, *params)
elif isinstance(op, ssa_ops.New):
expr = ast.Dummy('/*<unmerged new> {}*/', [ast.TypeName(op.tt)], isNew=True)
elif isinstance(op, ssa_ops.NewArray):
expr = ast.ArrayCreation(op.tt, params[0])
elif isinstance(op, ssa_ops.Truncate):
tt = {(True,16): objtypes.ShortTT, (False,16): objtypes.CharTT, (True,8): objtypes.ByteTT}[op.signed, op.width]
expr = ast.Cast(ast.TypeName(tt), params[0])
if op.rval is not None and expr:
expr = ast.Assignment(getExpr(op.rval), expr)
if expr is None: # Temporary hack
if isinstance(op, (ssa_ops.TryReturn, ssa_ops.ExceptionPhi, ssa_ops.MagicThrow)):
return None # Don't print out anything
return ast.ExpressionStatement(expr)
#########################################################################################
def _createASTBlock(info, endk, node):
getExpr = lambda var: info.var(node, var)
op2expr = lambda op: _convertJExpr(op, getExpr, info.clsname)
block = node.block
if block is not None:
split_ind = 0
if isinstance(block.jump, ssa_jumps.OnException):
# find index of first throwing instruction, so we can insert eassigns before it later
assert isinstance(block.lines[-1], ssa_ops.ExceptionPhi)
split_ind = block.lines.index(block.lines[-1].params[0].origin)
lines_before = filter(None, map(op2expr, block.lines[:split_ind]))
lines_after = filter(None, map(op2expr, block.lines[split_ind:]))
else:
lines_before, lines_after = [], []
# Kind of hackish: If the block ends in a cast and hence it is not known to always
# succeed, assign the results of the cast rather than passing through the variable
# unchanged. The cast will actually be second to last in block.lines due to the ephi
outreplace = {}
if block and len(block.lines) >= 2:
temp_op = block.lines[-2]
if lines_after and isinstance(temp_op, ssa_ops.CheckCast):
assert isinstance(lines_after[-1].expr, ast.Cast)
var = temp_op.params[0]
cexpr = lines_after[-1].expr
lhs = info.var(node, var, True)
assert lhs != cexpr.params[1]
lines_after[-1].expr = ast.Assignment(lhs, cexpr)
nvar = outreplace[var] = lines_after[-1].expr.params[0]
nvar.dtype = cexpr.dtype
eassigns = []
nassigns = []
for n2 in node.successors:
assert (n2 in node.outvars) != (n2 in node.eassigns)
if n2 in node.eassigns:
for outv, inv in zip(node.eassigns[n2], n2.invars):
if outv is None: # this is how we mark the thrown exception, which
# obviously doesn't get an explicit assignment statement
continue
expr = ast.Assignment(info.var(n2, inv), info.var(node, outv))
if expr.params[0] != expr.params[1]:
eassigns.append(ast.ExpressionStatement(expr))
else:
for outv, inv in zip(node.outvars[n2], n2.invars):
right = outreplace.get(outv, info.var(node, outv))
expr = ast.Assignment(info.var(n2, inv), right)
if expr.params[0] != expr.params[1]:
nassigns.append(ast.ExpressionStatement(expr))
# Need to put exception assignments before first throwing statement
# While normal assignments must come last as they may depend on it
statements = lines_before + eassigns + lines_after + nassigns
norm_successors = node.normalSuccessors()
jump = None if block is None else block.jump
if isinstance(jump, (ssa_jumps.Rethrow, ssa_jumps.Return)):
assert not norm_successors
assert not node.eassigns and not node.outvars
if isinstance(jump, ssa_jumps.Rethrow):
param = info.var(node, jump.params[-1])
statements.append(ast.ThrowStatement(param))
else:
if len(jump.params) > 0:
param = info.var(node, jump.params[0])
statements.append(ast.ReturnStatement(param, info.return_tt))
else:
statements.append(ast.ReturnStatement())
breakKey, jumpKey = endk, None
elif len(norm_successors) == 0:
assert isinstance(jump, ssa_jumps.OnException)
breakKey, jumpKey = endk, None
elif len(norm_successors) == 1: # normal successors
breakKey, jumpKey = endk, norm_successors[0]._key
else: # case of if and switch jumps handled in parent scope
assert len(norm_successors) > 1
breakKey, jumpKey = endk, endk
new = ast.StatementBlock(info.labelgen, node._key, breakKey, statements, jumpKey)
assert None not in statements
return new
_cmp_strs = dict(zip(('eq','ne','lt','ge','gt','le'), "== != < >= > <=".split()))
def _createASTSub(info, current, ftitem, forceUnlabled=False):
begink = current.entryBlock._key
endk = ftitem.entryBlock._key if ftitem is not None else None
if isinstance(current, SEBlockItem):
return _createASTBlock(info, endk, current.node)
elif isinstance(current, SEScope):
ftitems = current.items[1:] + [ftitem]
parts = [_createASTSub(info, item, newft) for item, newft in zip(current.items, ftitems)]
return ast.StatementBlock(info.labelgen, begink, endk, parts, endk, labelable=(not forceUnlabled))
elif isinstance(current, SEWhile):
parts = [_createASTSub(info, scope, current, True) for scope in current.getScopes()]
return ast.WhileStatement(info.labelgen, begink, endk, tuple(parts))
elif isinstance(current, SETry):
assert len(current.getScopes()) == 2
parts = [_createASTSub(info, scope, ftitem, True) for scope in current.getScopes()]
catchnode = current.getScopes()[-1].entryBlock
declt = ast.CatchTypeNames(info.env, current.toptts)
if current.catchvar is None: # exception is ignored and hence not referred to by the graph, so we need to make our own
catchvar = info.customVar(declt.dtype, 'ignoredException')
else:
catchvar = info.var(catchnode, current.catchvar)
decl = ast.VariableDeclarator(declt, catchvar)
pairs = [(decl, parts[1])]
return ast.TryStatement(info.labelgen, begink, endk, parts[0], pairs)
# Create a fake key to represent the beginning of the conditional statement itself
# doesn't matter what it is as long as it's unique
midk = begink + (-1,)
node = current.head.node
jump = node.block.jump
if isinstance(current, SEIf):
parts = [_createASTSub(info, scope, ftitem, True) for scope in current.getScopes()]
cmp_str = _cmp_strs[jump.cmp]
exprs = [info.var(node, var) for var in jump.params]
ifexpr = ast.BinaryInfix(cmp_str, exprs, objtypes.BoolTT)
new = ast.IfStatement(info.labelgen, midk, endk, ifexpr, tuple(parts))
elif isinstance(current, SESwitch):
ftitems = current.ordered[1:] + [ftitem]
parts = [_createASTSub(info, item, newft, True) for item, newft in zip(current.ordered, ftitems)]
for part in parts:
part.breakKey = endk # createSub will assume break should be ft, which isn't the case with switch statements
expr = info.var(node, jump.params[0])
pairs = zip(current.ordered_keysets, parts)
new = ast.SwitchStatement(info.labelgen, midk, endk, expr, pairs)
# bundle head and if together so we can return as single statement
headscope = _createASTBlock(info, midk, node)
assert headscope.jumpKey is midk
return ast.StatementBlock(info.labelgen, begink, endk, [headscope, new], endk)
def createAST(method, ssagraph, seroot, namegen):
info = VarInfo(method, ssagraph.blocks, namegen)
astroot = _createASTSub(info, seroot, None)
return astroot, info

View File

@ -0,0 +1,179 @@
import collections
from .. import graph_util
from ..ssa import objtypes
from ..ssa.objtypes import BExpr, BoolTT, ByteTT, CharTT, IntTT, ShortTT
from . import ast
# Class union-find data structure except that we don't bother with weighting trees and singletons are implicit
# Also, booleans are forced to be seperate roots
FORCED_ROOTS = True, False
class UnionFind(object):
def __init__(self):
self.d = {}
def find(self, x):
if x not in self.d:
return x
path = [x]
while path[-1] in self.d:
path.append(self.d[path[-1]])
root = path.pop()
for y in path:
self.d[y] = root
return root
def union(self, x, x2):
if x is None or x2 is None:
return
root1, root2 = self.find(x), self.find(x2)
if root2 in FORCED_ROOTS:
root1, root2 = root2, root1
if root1 != root2 and root2 not in FORCED_ROOTS:
self.d[root2] = root1
##############################################################
def visitStatementTree(scope, callback, catchcb=None):
for item in scope.statements:
for sub in item.getScopes():
visitStatementTree(sub, callback, catchcb)
if item.expr is not None:
callback(item, item.expr)
if catchcb is not None and isinstance(item, ast.TryStatement):
for pair in item.pairs:
catchcb(pair[0])
int_tags = frozenset(map(objtypes.baset, [IntTT, ShortTT, CharTT, ByteTT, BoolTT]))
array_tags = frozenset(map(objtypes.baset, [ByteTT, BoolTT]) + [objtypes.BExpr])
# Fix int/bool and byte[]/bool[] vars
def boolizeVars(root, arg_vars):
varlist = []
sets = UnionFind()
def visitExpr(expr, forceExact=False):
# see if we have to merge
if isinstance(expr, ast.Assignment) or isinstance(expr, ast.BinaryInfix) and expr.opstr in ('==','!=','&','|','^'):
subs = [visitExpr(param) for param in expr.params]
sets.union(*subs) # these operators can work on either type but need the same type on each side
elif isinstance(expr, ast.ArrayAccess):
sets.union(False, visitExpr(expr.params[1])) # array index is int only
elif isinstance(expr, ast.BinaryInfix) and expr.opstr in ('* / % + - << >> >>>'):
sets.union(False, visitExpr(expr.params[0])) # these operators are int only
sets.union(False, visitExpr(expr.params[1]))
if isinstance(expr, ast.Local):
tag, dim = objtypes.baset(expr.dtype), objtypes.dim(expr.dtype)
if (dim == 0 and tag in int_tags) or (dim > 0 and tag in array_tags):
# the only "unknown" vars are bexpr[] and ints. All else have fixed types
if forceExact or (tag != BExpr and tag != objtypes.baset(IntTT)):
sets.union(tag == objtypes.baset(BoolTT), expr)
varlist.append(expr)
return sets.find(expr)
elif isinstance(expr, ast.Literal):
if expr.dtype == IntTT and expr.val not in (0,1):
return False
return None # if val is 0 or 1, or the literal is a null, it is freely convertable
elif isinstance(expr, ast.Assignment) or (isinstance(expr, ast.BinaryInfix) and expr.opstr in ('&','|','^')):
return subs[0]
elif isinstance(expr, (ast.ArrayAccess, ast.Parenthesis, ast.UnaryPrefix)):
return visitExpr(expr.params[0])
elif expr.dtype is not None and objtypes.baset(expr.dtype) != BExpr:
return expr.dtype[0] == objtypes.baset(BoolTT)
return None
def visitStatement(item, expr):
root = visitExpr(expr)
if isinstance(item, ast.ReturnStatement):
forced_val = (objtypes.baset(item.tt) == objtypes.baset(BoolTT))
sets.union(forced_val, root)
elif isinstance(item, ast.SwitchStatement):
sets.union(False, root) # Switch must take an int, not a bool
for expr in arg_vars:
visitExpr(expr, forceExact=True)
visitStatementTree(root, callback=visitStatement)
# Fix the propagated types
for var in set(varlist):
tag, dim = objtypes.baset(var.dtype), objtypes.dim(var.dtype)
assert tag in int_tags or (dim>0 and tag == BExpr)
# make everything bool which is not forced to int
if sets.find(var) != False:
var.dtype = objtypes.withDimInc(BoolTT, dim)
elif dim > 0:
var.dtype = objtypes.withDimInc(ByteTT, dim)
# Fix everything else back up
def fixExpr(item, expr):
for param in expr.params:
fixExpr(None, param)
if isinstance(expr, ast.Assignment):
left, right = expr.params
if objtypes.baset(left.dtype) in int_tags and objtypes.dim(left.dtype) == 0:
if not ast.isPrimativeAssignable(right.dtype, left.dtype):
expr.params = [left, ast.makeCastExpr(left.dtype, right)]
elif isinstance(expr, ast.BinaryInfix):
a, b = expr.params
# shouldn't need to do anything here for arrays
if expr.opstr in '== != & | ^' and a.dtype == BoolTT or b.dtype == BoolTT:
expr.params = [ast.makeCastExpr(BoolTT, v) for v in expr.params]
visitStatementTree(root, callback=fixExpr)
# Fix vars of interface/object type
# TODO: do this properly
def interfaceVars(env, root, arg_vars):
varlist = []
consts = {}
assigns = collections.defaultdict(list)
def isInterfaceVar(expr):
if not isinstance(expr, ast.Local) or not objtypes.isBaseTClass(expr.dtype):
return False
if objtypes.className(expr.dtype) == objtypes.className(objtypes.ObjectTT):
return True
return env.isInterface(objtypes.className(expr.dtype))
def updateConst(var, tt):
varlist.append(var)
if var not in consts:
consts[var] = tt
else:
consts[var] = objtypes.commonSupertype(env, [consts[var], tt])
def visitStatement(item, expr):
if isinstance(expr, ast.Assignment) and objtypes.isBaseTClass(expr.dtype):
left, right = expr.params
if isInterfaceVar(left):
if isInterfaceVar(right):
assigns[left].append(right)
varlist.append(right)
varlist.append(left)
else:
updateConst(left, right.dtype)
def visitCatchDecl(decl):
updateConst(decl.local, decl.typename.dtype)
for expr in arg_vars:
if objtypes.isBaseTClass(expr.dtype):
updateConst(expr, expr.dtype)
visitStatementTree(root, callback=visitStatement, catchcb=visitCatchDecl)
# Now calculate actual types and fix
newtypes = {}
# visit variables in topological order. Doesn't handle case of loops, but this is a temporary hack anyway
order = graph_util.topologicalSort(varlist, lambda v:assigns[v])
for var in order:
assert var not in newtypes
tts = [newtypes.get(right, objtypes.ObjectTT) for right in assigns[var]]
if var in consts:
tts.append(consts[var])
newtypes[var] = newtype = objtypes.commonSupertype(env, tts)
if newtype != objtypes.ObjectTT and newtype != var.dtype and newtype != objtypes.NullTT:
# assert objtypes.baset(var.dtype) == objtypes.baset(objtypes.ObjectTT)
var.dtype = newtype

View File

@ -0,0 +1,227 @@
from collections import defaultdict as ddict
from .. import graph_util
from ..ssa import objtypes
from . import ast
def flattenDict(replace):
for k in list(replace):
while replace[k] in replace:
replace[k] = replace[replace[k]]
# The basic block in our temporary CFG
# instead of code, it merely contains a list of defs and uses
# This is an extended basic block, i.e. it only terminates in a normal jump(s).
# exceptions can be thrown from various points within the block
class DUBlock(object):
def __init__(self, key):
self.key = key
self.caught_excepts = ()
self.lines = [] # 3 types of lines: ('use', var), ('def', (var, var2_opt)), or ('canthrow', None)
self.e_successors = []
self.n_successors = []
self.vars = None # vars used or defined within the block. Does NOT include caught exceptions
def canThrow(self): return ('canthrow', None) in self.lines
def recalcVars(self):
self.vars = set()
for line_t, data in self.lines:
if line_t == 'use':
self.vars.add(data)
elif line_t == 'def':
self.vars.add(data[0])
if data[1] is not None:
self.vars.add(data[1])
def replace(self, replace):
if not self.vars.isdisjoint(replace):
newlines = []
for line_t, data in self.lines:
if line_t == 'use':
data = replace.get(data, data)
elif line_t == 'def':
data = replace.get(data[0], data[0]), replace.get(data[1], data[1])
newlines.append((line_t, data))
self.lines = newlines
for k, v in replace.items():
if k in self.vars:
self.vars.remove(k)
self.vars.add(v)
def simplify(self):
# try to prune redundant instructions
last = None
newlines = []
for line in self.lines:
if line[0] == 'def':
if line[1][0] == line[1][1]:
continue
elif line == last:
continue
newlines.append(line)
last = line
self.lines = newlines
self.recalcVars()
def varOrNone(expr):
return expr if isinstance(expr, ast.Local) else None
def canThrow(expr):
if isinstance(expr, (ast.ArrayAccess, ast.ArrayCreation, ast.Cast, ast.ClassInstanceCreation, ast.FieldAccess, ast.MethodInvocation)):
return True
if isinstance(expr, ast.BinaryInfix) and expr.opstr in ('/','%'): # check for possible division by 0
return expr.dtype not in (objtypes.FloatTT, objtypes.DoubleTT)
return False
def visitExpr(expr, lines):
if expr is None:
return
if isinstance(expr, ast.Local):
lines.append(('use', expr))
if isinstance(expr, ast.Assignment):
lhs, rhs = map(varOrNone, expr.params)
# with assignment we need to only visit LHS if it isn't a local in order to avoid spurious uses
# also, we need to visit RHS before generating the def
if lhs is None:
visitExpr(expr.params[0], lines)
visitExpr(expr.params[1], lines)
if lhs is not None:
lines.append(('def', (lhs, rhs)))
else:
for param in expr.params:
visitExpr(param, lines)
if canThrow(expr):
lines.append(('canthrow', None))
class DUGraph(object):
def __init__(self):
self.blocks = []
self.entry = None
def makeBlock(self, key, break_dict, caught_except, myexcept_parents):
block = DUBlock(key)
self.blocks.append(block)
for parent in break_dict[block.key]:
parent.n_successors.append(block)
del break_dict[block.key]
assert (myexcept_parents is None) == (caught_except is None)
if caught_except is not None: # this is the head of a catch block:
block.caught_excepts = (caught_except,)
for parent in myexcept_parents:
parent.e_successors.append(block)
return block
def finishBlock(self, block, catch_stack):
# register exception handlers for completed old block and calculate var set
assert(block.vars is None) # make sure it wasn't finished twice
if block.canThrow():
for clist in catch_stack:
clist.append(block)
block.recalcVars()
def visitScope(self, scope, break_dict, catch_stack, caught_except=None, myexcept_parents=None, head_block=None):
# catch_stack is copy on modify
if head_block is None:
head_block = block = self.makeBlock(scope.continueKey, break_dict, caught_except, myexcept_parents)
else:
block = head_block
for stmt in scope.statements:
if isinstance(stmt, (ast.ExpressionStatement, ast.ThrowStatement, ast.ReturnStatement)):
visitExpr(stmt.expr, block.lines)
if isinstance(stmt, ast.ThrowStatement):
block.lines.append(('canthrow', None))
continue
# compound statements
assert stmt.continueKey is not None
if isinstance(stmt, (ast.IfStatement, ast.SwitchStatement)):
visitExpr(stmt.expr, block.lines)
if isinstance(stmt, ast.SwitchStatement):
ft = not stmt.hasDefault()
else:
ft = len(stmt.getScopes()) == 1
for sub in stmt.getScopes():
break_dict[sub.continueKey].append(block)
self.visitScope(sub, break_dict, catch_stack)
if ft:
break_dict[stmt.breakKey].append(block)
elif isinstance(stmt, ast.WhileStatement):
if stmt.expr != ast.Literal.TRUE: # while(cond)
assert stmt.breakKey is not None
self.finishBlock(block, catch_stack)
block = self.makeBlock(stmt.continueKey, break_dict, None, None)
visitExpr(stmt.expr, block.lines)
break_dict[stmt.breakKey].append(block)
break_dict[stmt.continueKey].append(block)
body_block = self.visitScope(stmt.getScopes()[0], break_dict, catch_stack)
continue_target = body_block if stmt.expr == ast.Literal.TRUE else block
for parent in break_dict[stmt.continueKey]:
parent.n_successors.append(continue_target)
del break_dict[stmt.continueKey]
elif isinstance(stmt, ast.TryStatement):
new_stack = catch_stack + [[] for _ in stmt.pairs]
break_dict[stmt.tryb.continueKey].append(block)
self.visitScope(stmt.tryb, break_dict, new_stack)
for cdecl, catchb in stmt.pairs:
parents = new_stack.pop()
self.visitScope(catchb, break_dict, catch_stack, cdecl.local, parents)
assert new_stack == catch_stack
else:
assert isinstance(stmt, ast.StatementBlock)
break_dict[stmt.continueKey].append(block)
self.visitScope(stmt, break_dict, catch_stack, head_block=block)
if not isinstance(stmt, ast.StatementBlock): # if we passed it to subscope, it will be finished in the subcall
self.finishBlock(block, catch_stack)
if stmt.breakKey is not None: # start new block after return from compound statement
block = self.makeBlock(stmt.breakKey, break_dict, None, None)
else:
block = None # should never be accessed anyway if we're exiting abruptly
if scope.jumpKey is not None:
break_dict[scope.jumpKey].append(block)
if block is not None:
self.finishBlock(block, catch_stack)
return head_block # head needs to be returned in case of loops so we can fix up backedges
def makeCFG(self, root):
break_dict = ddict(list)
self.visitScope(root, break_dict, [])
self.entry = self.blocks[0] # entry point should always be first block generated
reached = graph_util.topologicalSort([self.entry], lambda block:(block.n_successors + block.e_successors))
# if len(reached) != len(self.blocks):
# print 'warning, {} blocks unreachable!'.format(len(self.blocks) - len(reached))
self.blocks = reached
def replace(self, replace):
flattenDict(replace)
for block in self.blocks:
block.replace(replace)
def simplify(self):
for block in self.blocks:
block.simplify()
def makeGraph(root):
g = DUGraph()
g.makeCFG(root)
return g

View File

@ -0,0 +1,127 @@
from collections import defaultdict as ddict
import itertools
def unique(seq): return len(set(seq)) == len(seq)
# This module provides a view of the ssa graph that can be modified without
# touching the underlying graph. This proxy is tailored towards the need of
# cfg structuring, so it allows easy duplication and indirection of nodes,
# but assumes that the underlying variables and statements are immutable
class BlockProxy(object):
def __init__(self, key, counter, block=None):
self.bkey = key
self.num = next(counter)
self.counter = counter
self.block = block
self.predecessors = []
self.successors = []
self.outvars = {}
self.eassigns = {} # exception edge assignments, used after try constraint creation
self._key = self.bkey, self.num
# to be assigned later
self.invars = self.blockdict = None
# assigned by structuring.py calcNoLoopNeighbors
self.successors_nl = self.predecessors_nl = self.norm_suc_nl = None
def replaceSuccessors(self, rmap):
update = lambda k:rmap.get(k,k)
self.successors = map(update, self.successors)
self.outvars = {update(k):v for k,v in self.outvars.items()}
if self.block is not None:
d1 = self.blockdict
self.blockdict = {(b.key,t):update(d1[b.key,t]) for (b,t) in self.block.jump.getSuccessorPairs()}
def newIndirect(self): # for use during graph creation
new = BlockProxy(self.bkey, self.counter)
new.invars = self.invars
new.outvars = {self:new.invars}
new.blockdict = None
new.successors = [self]
self.predecessors.append(new)
return new
def newDuplicate(self): # for use by structuring.structure return inlining
new = BlockProxy(self.bkey, self.counter, self.block)
new.invars = self.invars
new.outvars = self.outvars.copy()
new.blockdict = self.blockdict
new.successors = self.successors[:]
return new
def indirectEdges(self, edges):
# Should only be called once graph is completely set up. newIndirect is used during graph creation
new = self.newIndirect()
for parent in edges:
self.predecessors.remove(parent)
new.predecessors.append(parent)
parent.replaceSuccessors({self:new})
return new
def normalSuccessors(self): # only works once try constraints have been created
return [x for x in self.successors if x in self.outvars]
def __str__(self): # pragma: no cover
fmt = 'PB {}x{}' if self.num else 'PB {0}'
return fmt.format(self.bkey, self.num)
__repr__ = __str__
def createGraphProxy(ssagraph):
assert(not ssagraph.procs) # should have already been inlined
nodes = [BlockProxy(b.key, itertools.count(), block=b) for b in ssagraph.blocks]
allnodes = nodes[:] # will also contain indirected nodes
entryNode = None
intypes = ddict(set)
for n in nodes:
invars = [phi.rval for phi in n.block.phis]
for b, t in n.block.jump.getSuccessorPairs():
intypes[b.key].add(t)
if n.bkey == ssagraph.entryKey:
assert(not entryNode and not invars) # shouldn't have more than one entryBlock and entryBlock shouldn't have phis
entryNode = n
invars = ssagraph.inputArgs # store them in the node so we don't have to keep track seperately
invars = [x for x in invars if x is not None] # will have None placeholders for Long and Double arguments
n.invars = invars
lookup = {}
for n in nodes:
assert len(intypes[n.bkey]) != 2 # should have been handled by graph.splitDualInedges()
if False in intypes[n.bkey]:
lookup[n.bkey, False] = n
if True in intypes[n.bkey]:
lookup[n.bkey, True] = n
assert unique(lookup.values())
for n in nodes:
n.blockdict = lookup
block = n.block
for (block2, t) in block.jump.getSuccessorPairs():
out = [phi.get((block, t)) for phi in block2.phis]
n2 = lookup[block2.key, t]
n.outvars[n2] = out
n.successors.append(n2)
n2.predecessors.append(n)
# sanity check
for n in allnodes:
assert (n.block is not None) == (n.num == 0)
assert (n is entryNode) == (len(n.predecessors) == 0)
assert unique(n.predecessors)
assert unique(n.successors)
for pn in n.predecessors:
assert n in pn.successors
assert set(n.outvars) == set(n.successors)
for sn in n.successors:
assert n in sn.predecessors
assert len(n.outvars[sn]) == len(sn.invars)
return entryNode, allnodes

View File

@ -0,0 +1,70 @@
import struct
from ..ssa import objtypes
from ..verifier.descriptors import parseFieldDescriptor
from . import ast, ast2, javamethod, throws
from .reserved import reserved_identifiers
def loadConstValue(cpool, index):
entry_type = cpool.pool[index][0]
args = cpool.getArgs(index)
# Note: field constant values cannot be class literals
tt = {'Int':objtypes.IntTT, 'Long':objtypes.LongTT,
'Float':objtypes.FloatTT, 'Double':objtypes.DoubleTT,
'String':objtypes.StringTT}[entry_type]
return ast.Literal(tt, args[0]).fixLiterals()
def _getField(field):
flags = [x.lower() for x in sorted(field.flags) if x not in ('SYNTHETIC','ENUM')]
desc = field.descriptor
dtype = objtypes.verifierToSynthetic(parseFieldDescriptor(desc, unsynthesize=False)[0])
initexpr = None
if field.static:
cpool = field.class_.cpool
const_attrs = [data for name,data in field.attributes if name == 'ConstantValue']
if const_attrs:
assert len(const_attrs) == 1
data = const_attrs[0]
index = struct.unpack('>h', data)[0]
initexpr = loadConstValue(cpool, index)
return ast2.FieldDef(' '.join(flags), ast.TypeName(dtype), field.class_, field.name, desc, initexpr)
def _getMethod(method, cb, forbidden_identifiers, skip_errors):
try:
graph = cb(method) if method.code is not None else None
print 'Decompiling method', method.name.encode('utf8'), method.descriptor.encode('utf8')
code_ast = javamethod.generateAST(method, graph, forbidden_identifiers)
return code_ast
except Exception as e:
if not skip_errors:
raise
import traceback
message = traceback.format_exc()
code_ast = javamethod.generateAST(method, None, forbidden_identifiers)
code_ast.comments.add(message)
print message
return code_ast
# Method argument allows decompilng only a single method, primarily useful for debugging
def generateAST(cls, cb, skip_errors, method=None, add_throws=False):
methods = cls.methods if method is None else [cls.methods[method]]
fi = set(reserved_identifiers)
for field in cls.fields:
fi.add(field.name)
forbidden_identifiers = frozenset(fi)
myflags = [x.lower() for x in sorted(cls.flags) if x not in ('INTERFACE','SUPER','SYNTHETIC','ANNOTATION','ENUM')]
isInterface = 'INTERFACE' in cls.flags
superc = cls.supername
interfaces = [cls.cpool.getArgsCheck('Class', index) for index in cls.interfaces_raw] # todo - change when class actually loads interfaces
field_defs = [_getField(f) for f in cls.fields]
method_defs = [_getMethod(m, cb, forbidden_identifiers, skip_errors) for m in methods]
if add_throws:
throws.addSingle(cls.env, method_defs)
return ast2.ClassDef(' '.join(myflags), isInterface, cls.name, superc, interfaces, field_defs, method_defs)

View File

@ -0,0 +1,886 @@
import collections
from functools import partial
import operator
from .. import graph_util
from ..namegen import NameGen
from ..ssa import objtypes
from ..verifier.descriptors import parseMethodDescriptor
from . import ast, ast2, astgen, boolize, graphproxy, mergevariables, structuring
class DeclInfo(object):
__slots__ = "declScope scope defs".split()
def __init__(self):
self.declScope = self.scope = None
self.defs = []
def findVarDeclInfo(root, predeclared):
info = collections.OrderedDict()
def visit(scope, expr):
for param in expr.params:
visit(scope, param)
if expr.isLocalAssign():
left, right = expr.params
info[left].defs.append(right)
elif isinstance(expr, (ast.Local, ast.Literal)):
# this would be so much nicer if we had Ordered defaultdicts
info.setdefault(expr, DeclInfo())
info[expr].scope = ast.StatementBlock.join(info[expr].scope, scope)
def visitDeclExpr(scope, expr):
info.setdefault(expr, DeclInfo())
assert scope is not None and info[expr].declScope is None
info[expr].declScope = scope
for expr in predeclared:
visitDeclExpr(root, expr)
stack = [(root,root)]
while stack:
scope, stmt = stack.pop()
if isinstance(stmt, ast.StatementBlock):
stack.extend((stmt,sub) for sub in stmt.statements)
else:
stack.extend((subscope,subscope) for subscope in stmt.getScopes())
# temp hack
if stmt.expr is not None:
visit(scope, stmt.expr)
if isinstance(stmt, ast.TryStatement):
for catchdecl, body in stmt.pairs:
visitDeclExpr(body, catchdecl.local)
return info
def reverseBoolExpr(expr):
assert expr.dtype == objtypes.BoolTT
if isinstance(expr, ast.BinaryInfix):
symbols = "== != < >= > <=".split()
floatts = (objtypes.FloatTT, objtypes.DoubleTT)
if expr.opstr in symbols:
sym2 = symbols[symbols.index(expr.opstr) ^ 1]
left, right = expr.params
# be sure not to reverse floating point comparisons since it's not equivalent for NaN
if expr.opstr in symbols[:2] or (left.dtype not in floatts and right.dtype not in floatts):
return ast.BinaryInfix(sym2, [left, right], objtypes.BoolTT)
elif isinstance(expr, ast.UnaryPrefix) and expr.opstr == '!':
return expr.params[0]
return ast.UnaryPrefix('!', expr)
def getSubscopeIter(root):
stack = [root]
while stack:
scope = stack.pop()
if isinstance(scope, ast.StatementBlock):
stack.extend(scope.statements)
yield scope
else:
stack.extend(scope.getScopes())
def mayBreakTo(root, forbidden):
assert None not in forbidden
for scope in getSubscopeIter(root):
if scope.jumpKey in forbidden:
# We return true if scope has forbidden jump and is reachable
# We assume there is no unreachable code, so in order for a scope
# jump to be unreachable, it must end in a return, throw, or a
# compound statement, all of which are not reachable or do not
# break out of the statement. We omit adding last.breakKey to
# forbidden since it should always match scope.jumpKey anyway
if not scope.statements:
return True
last = scope.statements[-1]
if not last.getScopes():
if not isinstance(last, (ast.ReturnStatement, ast.ThrowStatement)):
return True
else:
# If and switch statements may allow fallthrough
# A while statement with condition may break implicitly
if isinstance(last, ast.IfStatement) and len(last.getScopes()) == 1:
return True
if isinstance(last, ast.SwitchStatement) and not last.hasDefault():
return True
if isinstance(last, ast.WhileStatement) and last.expr != ast.Literal.TRUE:
return True
if not isinstance(last, ast.WhileStatement):
for sub in last.getScopes():
assert sub.breakKey == last.breakKey == scope.jumpKey
return False
def replaceKeys(top, replace):
assert None not in replace
get = lambda k:replace.get(k,k)
if top.getScopes():
if isinstance(top, ast.StatementBlock) and get(top.breakKey) is None:
# breakKey can be None with non-None jumpKey when we're a scope in a switch statement that falls through
# and the end of the switch statement is unreachable
assert get(top.jumpKey) is None or not top.labelable
top.breakKey = get(top.breakKey)
if isinstance(top, ast.StatementBlock):
top.jumpKey = get(top.jumpKey)
for item in top.statements:
replaceKeys(item, replace)
else:
for scope in top.getScopes():
replaceKeys(scope, replace)
NONE_SET = frozenset([None])
def _preorder(scope, func):
newitems = []
for i, item in enumerate(scope.statements):
for sub in item.getScopes():
_preorder(sub, func)
val = func(scope, item)
vals = [item] if val is None else val
newitems.extend(vals)
scope.statements = newitems
def _fixObjectCreations(scope, item):
'''Combines new/invokeinit pairs into Java constructor calls'''
# Thanks to the uninitialized variable merging prior to AST generation,
# we can safely assume there are no copies to worry about
expr = item.expr
if isinstance(expr, ast.Assignment):
left, right = expr.params
if isinstance(right, ast.Dummy) and right.isNew:
return [] # remove item
elif isinstance(expr, ast.MethodInvocation) and expr.name == '<init>':
left = expr.params[0]
newexpr = ast.ClassInstanceCreation(ast.TypeName(left.dtype), expr.tts[1:], expr.params[1:])
item.expr = ast.Assignment(left, newexpr)
def _pruneRethrow_cb(item):
'''Convert try{A} catch(T t) {throw t;} to {A}'''
while item.pairs:
decl, body = item.pairs[-1]
caught, lines = decl.local, body.statements
if len(lines) == 1:
line = lines[0]
if isinstance(line, ast.ThrowStatement) and line.expr == caught:
item.pairs = item.pairs[:-1]
continue
break
if not item.pairs:
new = item.tryb
assert new.breakKey == item.breakKey
assert new.continueKey == item.continueKey
assert not new.labelable
new.labelable = True
return new
return item
def _pruneIfElse_cb(item):
'''Convert if(A) {B} else {} to if(A) {B}'''
if len(item.scopes) > 1:
tblock, fblock = item.scopes
# if true block is empty, swap it with false so we can remove it
if not tblock.statements and tblock.doesFallthrough():
item.expr = reverseBoolExpr(item.expr)
tblock, fblock = item.scopes = fblock, tblock
if not fblock.statements and fblock.doesFallthrough():
item.scopes = tblock,
# if(A) {B throw/return ... } else {C} -> if(A) {B throw/return ...} {C}
if len(item.scopes) > 1:
# How much we want the block to be first, or (3, 0) if it can't be simplified
def priority(block):
# If an empty block survives to this point, it must end in a break so we can simplify in this case too
if not block.statements:
return 2, 0
# If any complex statements, there might be a labeled break, so it's not safe
if any(stmt.getScopes() for stmt in block.statements):
return 3, 0
# prefer if(...) {throw...} return... to if(...) {return...} throw...
if isinstance(block.statements[-1], ast.ThrowStatement):
return 0, len(block.statements)
elif isinstance(block.statements[-1], ast.ReturnStatement):
return 1, len(block.statements)
return 3, 0
if priority(fblock) < priority(tblock):
item.expr = reverseBoolExpr(item.expr)
tblock, fblock = item.scopes = fblock, tblock
if priority(tblock) < (3, 0):
assert tblock.statements or not tblock.doesFallthrough()
item.scopes = tblock,
item.breakKey = fblock.continueKey
fblock.labelable = True
return [item], fblock
# If cond is !(x), reverse it back to simplify cond
if isinstance(item.expr, ast.UnaryPrefix) and item.expr.opstr == '!':
item.expr = reverseBoolExpr(item.expr)
tblock, fblock = item.scopes = fblock, tblock
# if(A) {if(B) {C}} -> if(A && B) {C}
tblock = item.scopes[0]
if len(item.scopes) == 1 and len(tblock.statements) == 1 and tblock.doesFallthrough():
first = tblock.statements[0]
if isinstance(first, ast.IfStatement) and len(first.scopes) == 1:
item.expr = ast.BinaryInfix('&&',[item.expr, first.expr], objtypes.BoolTT)
item.scopes = first.scopes
return [], item
def _whileCondition_cb(item):
'''Convert while(true) {if(A) {B break;} else {C} D} to while(!A) {{C} D} {B}
and while(A) {if(B) {break;} else {C} D} to while(A && !B) {{C} D}'''
failure = [], item # what to return if we didn't inline
body = item.getScopes()[0]
if not body.statements or not isinstance(body.statements[0], ast.IfStatement):
return failure
head = body.statements[0]
cond = head.expr
trueb, falseb = (head.getScopes() + (None,))[:2]
# Make sure it doesn't continue the loop or break out of the if statement
badjumps1 = frozenset([head.breakKey, item.continueKey]) - NONE_SET
if mayBreakTo(trueb, badjumps1):
if falseb is not None and not mayBreakTo(falseb, badjumps1):
cond = reverseBoolExpr(cond)
trueb, falseb = falseb, trueb
else:
return failure
assert not mayBreakTo(trueb, badjumps1)
trivial = not trueb.statements and trueb.jumpKey == item.breakKey
# If we already have a condition, only a simple break is allowed
if not trivial and item.expr != ast.Literal.TRUE:
return failure
# If break body is nontrival, we can't insert this after the end of the loop unless
# We're sure that nothing else in the loop breaks out
badjumps2 = frozenset([item.breakKey]) - NONE_SET
if not trivial:
restloop = [falseb] if falseb is not None else []
restloop += body.statements[1:]
if body.jumpKey == item.breakKey or any(mayBreakTo(s, badjumps2) for s in restloop):
return failure
# Now inline everything
item.expr = _simplifyExpressions(ast.BinaryInfix('&&', [item.expr, reverseBoolExpr(cond)]))
if falseb is None:
body.continueKey = body.statements.pop(0).breakKey
else:
body.continueKey = falseb.continueKey
body.statements[0] = falseb
falseb.labelable = True
trueb.labelable = True
if item.breakKey is None: # Make sure to maintain invariant that bkey=None -> jkey=None
assert trueb.doesFallthrough()
trueb.jumpKey = trueb.breakKey = None
trueb.breakKey = item.breakKey
assert trueb.continueKey is not None
if not trivial:
item.breakKey = trueb.continueKey
# Trueb doesn't break to head.bkey but there might be unreacahble jumps, so we replace
# it too. We don't replace item.ckey because it should never appear, even as an
# unreachable jump
replaceKeys(trueb, {head.breakKey:trueb.breakKey, item.breakKey:trueb.breakKey})
return [item], trueb
def _simplifyBlocksSub(scope, item, isLast):
rest = []
if isinstance(item, ast.TryStatement):
item = _pruneRethrow_cb(item)
elif isinstance(item, ast.IfStatement):
rest, item = _pruneIfElse_cb(item)
elif isinstance(item, ast.WhileStatement):
rest, item = _whileCondition_cb(item)
if isinstance(item, ast.StatementBlock):
assert item.breakKey is not None or item.jumpKey is None
# If bkey is None, it can't be broken to
# If contents can also break to enclosing scope, it's always safe to inline
bkey = item.breakKey
if bkey is None or (bkey == scope.breakKey and scope.labelable):
rest, item.statements = rest + item.statements, []
# Now inline statements at the beginning of the scope that don't need to break to it
for sub in item.statements[:]:
if sub.getScopes() and sub.breakKey != bkey and mayBreakTo(sub, frozenset([bkey])):
break
rest.append(item.statements.pop(0))
if not item.statements:
if item.jumpKey != bkey:
assert isLast
scope.jumpKey = item.jumpKey
assert scope.breakKey is not None or scope.jumpKey is None
return rest
return rest + [item]
def _simplifyBlocks(scope):
newitems = []
for item in reversed(scope.statements):
isLast = not newitems # may be true if all subsequent items pruned
if isLast and item.getScopes():
if item.breakKey != scope.jumpKey:# and item.breakKey is not None:
replaceKeys(item, {item.breakKey: scope.jumpKey})
for sub in reversed(item.getScopes()):
_simplifyBlocks(sub)
vals = _simplifyBlocksSub(scope, item, isLast)
newitems += reversed(vals)
scope.statements = newitems[::-1]
_op2bits = {'==':2, '!=':13, '<':1, '<=':3, '>':4, '>=':6}
_bit2ops_float = {v:k for k,v in _op2bits.items()}
_bit2ops = {(v & 7):k for k,v in _op2bits.items()}
def _getBitfield(expr):
if isinstance(expr, ast.BinaryInfix):
if expr.opstr in ('==','!=','<','<=','>','>='):
# We don't want to merge expressions if they could have side effects
# so only allow literals and locals
if all(isinstance(p, (ast.Literal, ast.Local)) for p in expr.params):
return _op2bits[expr.opstr], list(expr.params)
elif expr.opstr in ('&','&&','|','||'):
bits1, args1 = _getBitfield(expr.params[0])
bits2, args2 = _getBitfield(expr.params[1])
if args1 == args2:
bits = (bits1 & bits2) if '&' in expr.opstr else (bits1 | bits2)
return bits, args1
elif isinstance(expr, ast.UnaryPrefix) and expr.opstr == '!':
bits, args = _getBitfield(expr.params[0])
return ~bits, args
return 0, None
def _mergeComparisons(expr):
# a <= b && a != b -> a < b, etc.
bits, args = _getBitfield(expr)
if args is None:
return expr
assert not hasSideEffects(args[0]) and not hasSideEffects(args[1])
if args[0].dtype in (objtypes.FloatTT, objtypes.DoubleTT):
mask, d = 15, _bit2ops_float
else:
mask, d = 7, _bit2ops
bits &= mask
notbits = (~bits) & mask
if bits == 0:
return ast.Literal.TRUE
elif notbits == 0:
return ast.Literal.FALSE
elif bits in d:
return ast.BinaryInfix(d[bits], args, objtypes.BoolTT)
elif notbits in d:
return ast.UnaryPrefix('!', ast.BinaryInfix(d[notbits], args, objtypes.BoolTT))
return expr
def _simplifyExpressions(expr):
TRUE, FALSE = ast.Literal.TRUE, ast.Literal.FALSE
bools = {True:TRUE, False:FALSE}
opfuncs = {'<': operator.lt, '<=': operator.le, '>': operator.gt, '>=': operator.ge}
simplify = _simplifyExpressions
expr.params = map(simplify, expr.params)
if isinstance(expr, ast.BinaryInfix):
left, right = expr.params
op = expr.opstr
if op in ('==','!=','<','<=','>','>=') and isinstance(right, ast.Literal):
# la cmp lb -> result (i.e. constant propagation on literal comparisons)
if isinstance(left, ast.Literal):
if op in ('==','!='):
# these could be string or class literals, but those are always nonnull so it still works
res = (left == right) == (op == '==')
else:
assert left.dtype == right.dtype
res = opfuncs[op](left.val, right.val)
expr = bools[res]
# (a ? lb : c) cmp ld -> a ? (lb cmp ld) : (c cmp ld)
elif isinstance(left, ast.Ternary) and isinstance(left.params[1], ast.Literal):
left.params[1] = simplify(ast.BinaryInfix(op, [left.params[1], right], expr._dtype))
left.params[2] = simplify(ast.BinaryInfix(op, [left.params[2], right], expr._dtype))
expr = left
# a ? true : b -> a || b
# a ? false : b -> !a && b
if isinstance(expr, ast.Ternary) and expr.dtype == objtypes.BoolTT:
cond, val1, val2 = expr.params
if not isinstance(val1, ast.Literal): # try to get bool literal to the front
cond, val1, val2 = reverseBoolExpr(cond), val2, val1
if val1 == TRUE:
expr = ast.BinaryInfix('||', [cond, val2], objtypes.BoolTT)
elif val1 == FALSE:
expr = ast.BinaryInfix('&&', [reverseBoolExpr(cond), val2], objtypes.BoolTT)
# true && a -> a, etc.
if isinstance(expr, ast.BinaryInfix) and expr.opstr in ('&&','||'):
left, right = expr.params
if expr.opstr == '&&':
if left == TRUE or (right == FALSE and not hasSideEffects(left)):
expr = right
elif left == FALSE or right == TRUE:
expr = left
else:
if left == TRUE or right == FALSE:
expr = left
elif left == FALSE or (right == TRUE and not hasSideEffects(left)):
expr = right
# a > b || a == b -> a >= b, etc.
expr = _mergeComparisons(expr)
# a == true -> a
# a == false -> !a
if isinstance(expr, ast.BinaryInfix) and expr.opstr in ('==, !=') and expr.params[0].dtype == objtypes.BoolTT:
left, right = expr.params
if not isinstance(left, ast.Literal): # try to get bool literal to the front
left, right = right, left
if isinstance(left, ast.Literal):
flip = (left == TRUE) != (expr.opstr == '==')
expr = reverseBoolExpr(right) if flip else right
# !a ? b : c -> a ? c : b
if isinstance(expr, ast.Ternary) and isinstance(expr.params[0], ast.UnaryPrefix):
cond, val1, val2 = expr.params
if cond.opstr == '!':
expr.params = [reverseBoolExpr(cond), val2, val1]
# 0 - a -> -a
if isinstance(expr, ast.BinaryInfix) and expr.opstr == '-':
if expr.params[0] == ast.Literal.ZERO or expr.params[0] == ast.Literal.LZERO:
expr = ast.UnaryPrefix('-', expr.params[1])
# (double)4.2f -> 4.2, etc.
if isinstance(expr, ast.Cast) and isinstance(expr.params[1], ast.Literal):
expr = ast.makeCastExpr(expr.dtype, expr.params[1])
return expr
def _setScopeParents(scope):
for item in scope.statements:
for sub in item.getScopes():
sub.bases = scope.bases + (sub,)
_setScopeParents(sub)
def _replaceExpressions(scope, item, rdict):
# Must be done before local declarations are created since it doesn't touch/remove them
if item.expr is not None:
item.expr = item.expr.replaceSubExprs(rdict)
# remove redundant assignments i.e. x=x;
if isinstance(item.expr, ast.Assignment):
assert isinstance(item, ast.ExpressionStatement)
left, right = item.expr.params
if left == right:
return []
return [item]
def _oldMergeVariables(root, predeclared):
_setScopeParents(root)
info = findVarDeclInfo(root, predeclared)
lvars = [expr for expr in info if isinstance(expr, ast.Local)]
forbidden = set()
# If var has any defs which aren't a literal or local, mark it as a leaf node (it can't be merged into something)
for var in lvars:
if not all(isinstance(expr, (ast.Local, ast.Literal)) for expr in info[var].defs):
forbidden.add(var)
elif info[var].declScope is not None:
forbidden.add(var)
sccs = graph_util.tarjanSCC(lvars, lambda var:([] if var in forbidden else info[var].defs))
# the sccs will be in topolgical order
varmap = {}
for scc in sccs:
if forbidden.isdisjoint(scc):
alldefs = []
for expr in scc:
for def_ in info[expr].defs:
if def_ not in scc:
alldefs.append(varmap[def_])
if len(set(alldefs)) == 1:
target = alldefs[0]
if all(var.dtype == target.dtype for var in scc):
scope = ast.StatementBlock.join(*(info[var].scope for var in scc))
scope = ast.StatementBlock.join(scope, info[target].declScope) # scope is unchanged if declScope is none like usual
if info[target].declScope is None or info[target].declScope == scope:
for var in scc:
varmap[var] = target
info[target].scope = ast.StatementBlock.join(scope, info[target].scope)
continue
# fallthrough if merging is impossible
for var in scc:
varmap[var] = var
if len(info[var].defs) > 1:
forbidden.add(var)
_preorder(root, partial(_replaceExpressions, rdict=varmap))
def _mergeVariables(root, predeclared, isstatic):
_oldMergeVariables(root, predeclared)
rdict = mergevariables.mergeVariables(root, isstatic, predeclared)
_preorder(root, partial(_replaceExpressions, rdict=rdict))
_oktypes = ast.BinaryInfix, ast.Local, ast.Literal, ast.Parenthesis, ast.Ternary, ast.TypeName, ast.UnaryPrefix
def hasSideEffects(expr):
if not isinstance(expr, _oktypes):
return True
# check for division by 0. If it's a float or dividing by nonzero literal, it's ok
elif isinstance(expr, ast.BinaryInfix) and expr.opstr in ('/','%'):
if expr.dtype not in (objtypes.FloatTT, objtypes.DoubleTT):
divisor = expr.params[-1]
if not isinstance(divisor, ast.Literal) or divisor.val == 0:
return True
return False
def _inlineVariables(root):
# first find all variables with a single def and use
defs = collections.defaultdict(list)
uses = collections.defaultdict(int)
def visitExprFindDefs(expr):
if expr.isLocalAssign():
defs[expr.params[0]].append(expr)
elif isinstance(expr, ast.Local):
uses[expr] += 1
def visitFindDefs(scope, item):
if item.expr is not None:
stack = [item.expr]
while stack:
expr = stack.pop()
visitExprFindDefs(expr)
stack.extend(expr.params)
_preorder(root, visitFindDefs)
# These should have 2 uses since the initial assignment also counts
replacevars = {k for k,v in defs.items() if len(v)==1 and uses[k]==2 and k.dtype == v[0].params[1].dtype}
def doReplacement(item, pairs):
old, new = item.expr.params
assert isinstance(old, ast.Local) and old.dtype == new.dtype
stack = [(True, (True, item2, expr)) for item2, expr in reversed(pairs) if expr is not None]
while stack:
recurse, args = stack.pop()
if recurse:
canReplace, parent, expr = args
stack.append((False, expr))
# TODO - fix this for real
if expr.complexity() > 30:
canReplace = False
# For ternaries, we don't want to replace into the conditionally
# evaluated part, but we still need to check those parts for
# barriers. For both ternaries and short circuit operators, the
# first param is always evaluated, so it is safe
if isinstance(expr, ast.Ternary) or isinstance(expr, ast.BinaryInfix) and expr.opstr in ('&&','||'):
for param in reversed(expr.params[1:]):
stack.append((True, (False, expr, param)))
stack.append((True, (canReplace, expr, expr.params[0])))
# For assignments, we unroll the LHS arguments, because if assigning
# to an array or field, we don't want that to serve as a barrier
elif isinstance(expr, ast.Assignment):
left, right = expr.params
stack.append((True, (canReplace, expr, right)))
if isinstance(left, (ast.ArrayAccess, ast.FieldAccess)):
for param in reversed(left.params):
stack.append((True, (canReplace, left, param)))
else:
assert isinstance(left, ast.Local)
else:
for param in reversed(expr.params):
stack.append((True, (canReplace, expr, param)))
if expr == old:
if canReplace:
if isinstance(parent, ast.JavaExpression):
params = parent.params = list(parent.params)
params[params.index(old)] = new
else: # replacing in a top level statement
assert parent.expr == old
parent.expr = new
return canReplace
else:
expr = args
if hasSideEffects(expr):
return False
return False
def visitReplace(scope):
newstatements = []
for item in reversed(scope.statements):
for sub in item.getScopes():
visitReplace(sub)
if isinstance(item.expr, ast.Assignment) and item.expr.params[0] in replacevars:
expr_roots = []
for item2 in newstatements:
# Don't inline into a while condition as it may be evaluated more than once
if not isinstance(item2, ast.WhileStatement):
expr_roots.append((item2, item2.expr))
if item2.getScopes():
break
success = doReplacement(item, expr_roots)
if success:
continue
newstatements.insert(0, item)
scope.statements = newstatements
visitReplace(root)
def _createDeclarations(root, predeclared):
_setScopeParents(root)
info = findVarDeclInfo(root, predeclared)
localdefs = collections.defaultdict(list)
newvars = [var for var in info if isinstance(var, ast.Local) and info[var].declScope is None]
remaining = set(newvars)
def mdVisitVarUse(var):
decl = ast.VariableDeclarator(ast.TypeName(var.dtype), var)
# The compiler treats statements as if they can throw any exception at any time, so
# it may think variables are not definitely assigned even when they really are.
# Therefore, we give an unused initial value to every variable declaration
# TODO - find a better way to handle this
right = ast.dummyLiteral(var.dtype)
localdefs[info[var].scope].append(ast.LocalDeclarationStatement(decl, right))
remaining.remove(var)
def mdVisitScope(scope):
if isinstance(scope, ast.StatementBlock):
for i,stmt in enumerate(scope.statements):
if isinstance(stmt, ast.ExpressionStatement):
if isinstance(stmt.expr, ast.Assignment):
var, right = stmt.expr.params
if var in remaining and scope == info[var].scope:
decl = ast.VariableDeclarator(ast.TypeName(var.dtype), var)
new = ast.LocalDeclarationStatement(decl, right)
scope.statements[i] = new
remaining.remove(var)
if stmt.expr is not None:
top = stmt.expr
for expr in top.postFlatIter():
if expr in remaining:
mdVisitVarUse(expr)
for sub in stmt.getScopes():
mdVisitScope(sub)
mdVisitScope(root)
# print remaining
assert not remaining
assert None not in localdefs
for scope, ldefs in localdefs.items():
scope.statements = ldefs + scope.statements
def _createTernaries(scope, item):
if isinstance(item, ast.IfStatement) and len(item.getScopes()) == 2:
block1, block2 = item.getScopes()
if (len(block1.statements) == len(block2.statements) == 1) and block1.jumpKey == block2.jumpKey:
s1, s2 = block1.statements[0], block2.statements[0]
e1, e2 = s1.expr, s2.expr
if isinstance(s1, ast.ReturnStatement) and isinstance(s2, ast.ReturnStatement):
expr = None if e1 is None else ast.Ternary(item.expr, e1, e2)
item = ast.ReturnStatement(expr, s1.tt)
elif isinstance(s1, ast.ExpressionStatement) and isinstance(s2, ast.ExpressionStatement):
if isinstance(e1, ast.Assignment) and isinstance(e2, ast.Assignment):
# if e1.params[0] == e2.params[0] and max(e1.params[1].complexity(), e2.params[1].complexity()) <= 1:
if e1.params[0] == e2.params[0]:
expr = ast.Ternary(item.expr, e1.params[1], e2.params[1])
temp = ast.ExpressionStatement(ast.Assignment(e1.params[0], expr))
if not block1.doesFallthrough():
assert not block2.doesFallthrough()
item = ast.StatementBlock(item.func, item.continueKey, item.breakKey, [temp], block1.jumpKey)
else:
item = temp
if item.expr is not None:
item.expr = _simplifyExpressions(item.expr)
return [item]
def _fixExprStatements(scope, item, namegen):
if isinstance(item, ast.ExpressionStatement):
right = item.expr
if not isinstance(right, (ast.Assignment, ast.ClassInstanceCreation, ast.MethodInvocation)) and right.dtype is not None:
left = ast.Local(right.dtype, lambda expr:namegen.getPrefix('dummy'))
decl = ast.VariableDeclarator(ast.TypeName(left.dtype), left)
item = ast.LocalDeclarationStatement(decl, right)
return [item]
def _fixLiterals(scope, item):
item.fixLiterals()
def _addCastsAndParens(scope, item, env):
item.addCastsAndParens(env)
def _fallsThrough(scope, usedBreakTargets):
# Check if control reaches end of scope and there is no break statement
# We don't have to check keys since breaks should have already been generated for child scopes
# for main scope there won't be one yet, but we don't care since we're just looking for
# whether end of scope is reached on the main scope
if not scope.statements:
return True
last = scope.statements[-1]
if isinstance(last, (ast.JumpStatement, ast.ReturnStatement, ast.ThrowStatement)):
return False
elif not last.getScopes():
return True
# Scope ends with a complex statement. Determine whether it can fallthrough
if last in usedBreakTargets:
return True
# Less strict than Java reachability rules, but we aren't bothering to follow them exactly
if isinstance(last, ast.WhileStatement):
return last.expr != ast.Literal.TRUE
elif isinstance(last, ast.SwitchStatement):
return not last.hasDefault() or _fallsThrough(last.getScopes()[-1], usedBreakTargets)
else:
if isinstance(last, ast.IfStatement) and len(last.getScopes()) < 2:
return True
return any(_fallsThrough(sub, usedBreakTargets) for sub in last.getScopes())
def _chooseJump(choices, breakPair, continuePair):
assert None not in choices
if breakPair in choices:
return breakPair
if continuePair in choices:
return continuePair
# Try to find an already labeled target
for b, t in choices:
if b.label is not None:
return b, t
return choices[0]
def _generateJumps(scope, usedBreakTargets, targets=collections.defaultdict(tuple), breakPair=None, continuePair=None, fallthroughs=NONE_SET, dryRun=False):
assert None in fallthroughs
newfallthroughs = fallthroughs
newcontinuePair = continuePair
newbreakPair = breakPair
if scope.jumpKey not in fallthroughs:
newfallthroughs = frozenset([None, scope.jumpKey])
for item in reversed(scope.statements):
if not item.getScopes():
newfallthroughs = NONE_SET
continue
if isinstance(item, ast.WhileStatement):
newfallthroughs = frozenset([None, item.continueKey])
else:
newfallthroughs |= frozenset([item.breakKey])
newtargets = targets.copy()
if isinstance(item, ast.WhileStatement):
newtargets[item.continueKey] += ((item, True),)
newcontinuePair = item, True
if isinstance(item, (ast.WhileStatement, ast.SwitchStatement)):
newbreakPair = item, False
newtargets[item.breakKey] += ((item, False),)
for subscope in reversed(item.getScopes()):
_generateJumps(subscope, usedBreakTargets, newtargets, newbreakPair, newcontinuePair, newfallthroughs, dryRun=dryRun)
if isinstance(item, ast.SwitchStatement):
newfallthroughs = frozenset([None, subscope.continueKey])
newfallthroughs = frozenset([None, item.continueKey])
for item in scope.statements:
if isinstance(item, ast.StatementBlock) and item.statements:
if isinstance(item.statements[-1], ast.JumpStatement):
assert item is scope.statements[-1] or item in usedBreakTargets
# Now that we've visited children, decide if we need a jump for this scope, and if so, generate it
if scope.jumpKey not in fallthroughs:
# Figure out if this jump is actually reachable
if _fallsThrough(scope, usedBreakTargets):
target, isContinue = pair = _chooseJump(targets[scope.jumpKey], breakPair, continuePair)
if not isContinue:
usedBreakTargets.add(target)
if pair == breakPair or pair == continuePair:
target = None
# Now actually add the jump statement
if not dryRun:
scope.statements.append(ast.JumpStatement(target, isContinue))
def _pruneVoidReturn(scope):
if scope.statements:
last = scope.statements[-1]
if isinstance(last, ast.ReturnStatement) and last.expr is None:
scope.statements.pop()
def _pruneCtorSuper(scope):
if scope.statements:
stmt = scope.statements[0]
if isinstance(stmt, ast.ExpressionStatement):
expr = stmt.expr
if isinstance(expr, ast.MethodInvocation) and len(expr.params) == 0 and expr.name == 'super':
scope.statements = scope.statements[1:]
def generateAST(method, graph, forbidden_identifiers):
env = method.class_.env
namegen = NameGen(forbidden_identifiers)
class_ = method.class_
inputTypes = parseMethodDescriptor(method.descriptor, unsynthesize=False)[0]
tts = objtypes.verifierToSynthetic_seq(inputTypes)
if graph is not None:
graph.splitDualInedges()
graph.fixLoops()
entryNode, nodes = graphproxy.createGraphProxy(graph)
if not method.static:
entryNode.invars[0].name = 'this'
setree = structuring.structure(entryNode, nodes, (method.name == '<clinit>'))
ast_root, varinfo = astgen.createAST(method, graph, setree, namegen)
argsources = [varinfo.var(entryNode, var) for var in entryNode.invars]
disp_args = argsources if method.static else argsources[1:]
for expr, tt in zip(disp_args, tts):
expr.dtype = tt
decls = [ast.VariableDeclarator(ast.TypeName(expr.dtype), expr) for expr in disp_args]
################################################################################################
ast_root.bases = (ast_root,) # needed for our setScopeParents later
assert _generateJumps(ast_root, set(), dryRun=True) is None
_preorder(ast_root, _fixObjectCreations)
boolize.boolizeVars(ast_root, argsources)
boolize.interfaceVars(env, ast_root, argsources)
_simplifyBlocks(ast_root)
assert _generateJumps(ast_root, set(), dryRun=True) is None
_mergeVariables(ast_root, argsources, method.static)
_preorder(ast_root, _createTernaries)
_inlineVariables(ast_root)
_simplifyBlocks(ast_root)
_preorder(ast_root, _createTernaries)
_inlineVariables(ast_root)
_simplifyBlocks(ast_root)
_createDeclarations(ast_root, argsources)
_preorder(ast_root, partial(_fixExprStatements, namegen=namegen))
_preorder(ast_root, _fixLiterals)
_preorder(ast_root, partial(_addCastsAndParens, env=env))
_generateJumps(ast_root, set())
_pruneVoidReturn(ast_root)
_pruneCtorSuper(ast_root)
else: # abstract or native method
ast_root = None
argsources = [ast.Local(tt, lambda expr:namegen.getPrefix('arg')) for tt in tts]
decls = [ast.VariableDeclarator(ast.TypeName(expr.dtype), expr) for expr in argsources]
flags = method.flags - set(['BRIDGE','SYNTHETIC','VARARGS'])
if method.name == '<init>': # More arbtirary restrictions. Yay!
flags = flags - set(['ABSTRACT','STATIC','FINAL','NATIVE','STRICTFP','SYNCHRONIZED'])
flagstr = ' '.join(map(str.lower, sorted(flags)))
inputTypes, returnTypes = parseMethodDescriptor(method.descriptor, unsynthesize=False)
ret_tt = objtypes.verifierToSynthetic(returnTypes[0]) if returnTypes else objtypes.VoidTT
return ast2.MethodDef(class_, flagstr, method.name, method.descriptor, ast.TypeName(ret_tt), decls, ast_root)

View File

@ -0,0 +1,285 @@
import heapq
from .cfg import flattenDict, makeGraph
# Variables x and y can safely be merged when it is true that for any use of y (respectively x)
# that sees a definition of y, either there are no intervening definitions of x, or x was known
# to be equal to y *at the point of its most recent definition*
# Given this info, we greedily merge related variables, that is, those where one is assigned to the other
# to calculate which variables can be merged, we first have to build a CFG from the Java AST again
class VarInfo(object):
__slots__ = "key", "defs", "rdefs", "extracount"
def __init__(self, key):
self.key = key
self.defs = set()
self.rdefs = set()
self.extracount = 0
def priority(self):
return (len(self.defs) + self.extracount), self.key
class EqualityData(object):
def __init__(self, d=None):
# Equal values point to a representative object instance. Singletons are not represented at all for efficiency
# None represents the top value (i.e. this point has not been visited yet)
self.d = d.copy() if d is not None else None
def _newval(self): return object()
def initialize(self): # initialize to bottom value (all variables unequal)
assert self.d is None
self.d = {}
def handleAssign(self, var1, var2=None):
if var1 == var2:
return
if var2 is None:
if var1 in self.d:
del self.d[var1]
else:
self.d[var1] = self.d.setdefault(var2, self._newval())
assert self.iseq(var1, var2)
def iseq(self, var1, var2):
assert var1 != var2
return var1 in self.d and var2 in self.d and self.d[var1] is self.d[var2]
def merge_update(self, other):
if other.d is None:
return
elif self.d is None:
self.d = other.d.copy()
else:
d1, d2 = self.d, other.d
new = {}
todo = list(set(d1) & set(d2))
while todo:
cur = todo.pop()
matches = [k for k in todo if d1[k] is d1[cur] and d2[k] is d2[cur]]
if not matches:
continue
new[cur] = self._newval()
for k in matches:
new[k] = new[cur]
todo = [k for k in todo if k not in new]
self.d = new
def copy(self): return EqualityData(self.d)
def __eq__(self, other):
if self.d is None or other.d is None:
return self.d is other.d
if self.d == other.d:
return True
if set(self.d) != set(other.d):
return False
match = {}
for k in self.d:
if match.setdefault(self.d[k], other.d[k]) != other.d[k]:
return False
return True
def __ne__(self, other): return not self == other
def __hash__(self): raise TypeError('unhashable type')
def calcEqualityData(graph):
graph.simplify()
blocks = graph.blocks
d = {b:[EqualityData()] for b in blocks}
d[graph.entry][0].initialize()
stack = [graph.entry]
dirty = set(blocks)
while stack:
block = stack.pop()
if block not in dirty:
continue
dirty.remove(block)
cur = d[block][0].copy()
e_out = EqualityData()
del d[block][1:]
for line_t, data in block.lines:
if line_t == 'def':
cur.handleAssign(*data)
d[block].append(cur.copy())
elif line_t == 'canthrow':
e_out.merge_update(cur)
for out, successors in [(e_out, block.e_successors), (cur, block.n_successors)]:
stack += successors
for suc in successors:
old = d[suc][0].copy()
d[suc][0].merge_update(out)
if old != d[suc][0]:
dirty.add(suc)
for block in blocks:
assert d[block][0].d is not None
assert not dirty
return d
class VarMergeInfo(object):
def __init__(self, graph, methodparams, isstatic):
self.info = {}
self.final, self.unmergeable, self.external = set(), set(), set()
self.equality = None # to be calculated later
self.graph = graph
self.pending_graph_replaces = {}
self.touched_vars = set()
# initialize variables and assignment data
for var in methodparams:
self._addvar(var)
self.external.update(methodparams)
if not isstatic:
self.final.add(methodparams[0])
for block in graph.blocks:
for line_t, data in block.lines:
if line_t == 'def':
self._addassign(data[0], data[1])
for caught in block.caught_excepts:
self._addvar(caught)
self.external.add(caught)
self.unmergeable.add(caught)
# initialization helper funcs
def _addvar(self, v):
return self.info.setdefault(v, VarInfo(len(self.info)))
def _addassign(self, v1, v2):
info = self._addvar(v1)
if v2 is not None:
info.defs.add(v2)
self._addvar(v2).rdefs.add(v1)
else:
info.extracount += 1
# process helper funcs
def iseq(self, block, index, v1, v2):
return self.equality[block][index].iseq(v1, v2)
def _doGraphReplacements(self):
self.graph.replace(self.pending_graph_replaces)
self.pending_graph_replaces = {}
self.touched_vars = set()
def compat(self, v1, v2, doeq):
if v1 in self.touched_vars or v2 in self.touched_vars:
self._doGraphReplacements()
blocks = self.graph.blocks
vok = {b:3 for b in blocks} # use bitmask v1ok = 1<<0, v2ok = 1<<1
stack = [b for b in blocks if v1 in b.vars or v2 in b.vars]
while stack:
block = stack.pop()
cur = vok[block]
e_out = 3
if v1 in block.vars or v2 in block.vars:
defcount = 0
for line_t, data in block.lines:
if line_t == 'use':
if (data == v1 and not cur & 1) or (data == v2 and not cur & 2):
return False
elif line_t == 'def':
defcount += 1
if data[0] == v1 and data[1] != v1:
cur = 1
elif data[0] == v2 and data[1] != v2:
cur = 2
if doeq and self.iseq(block, defcount, v1, v2):
cur = 3
elif line_t == 'canthrow':
e_out &= cur
else:
# v1 and v2 not touched in this block, so there is nothing to do
e_out = cur
for out, successors in [(e_out, block.e_successors), (cur, block.n_successors)]:
for suc in successors:
if vok[suc] & out != vok[suc]:
stack.append(suc)
vok[suc] &= out
return True
def process(self, replace, doeq):
final, unmergeable, external = self.final, self.unmergeable, self.external
d = self.info
work_q = [(info.priority(), var) for var, info in d.items()]
heapq.heapify(work_q)
dirty = set(d) - external
while work_q:
_, cur = heapq.heappop(work_q)
if (cur in external) or cur not in dirty:
continue
dirty.remove(cur)
candidate_set = d[cur].defs - unmergeable
if len(d[cur].defs) > 1 or d[cur].extracount > 0:
candidate_set = candidate_set - final
candidates = [v for v in candidate_set if v.dtype == cur.dtype]
candidates = sorted(candidates, key=lambda v:d[v].key)
assert cur not in candidates
# find first candidate that is actually compatible
for parent in candidates:
if self.compat(cur, parent, doeq):
break
else:
continue # no candidates found
replace[cur] = parent
self.pending_graph_replaces[cur] = parent
self.touched_vars.add(cur)
self.touched_vars.add(parent)
infc, infp = d[cur], d[parent]
# Be careful, there could be a loop with cur in parent.defs
infc.defs.remove(parent)
infc.rdefs.discard(parent)
infp.rdefs.remove(cur)
infp.defs.discard(cur)
for var in d[cur].rdefs:
d[var].defs.remove(cur)
d[var].defs.add(parent)
heapq.heappush(work_q, (d[var].priority(), var))
for var in d[cur].defs:
d[var].rdefs.remove(cur)
d[var].rdefs.add(parent)
d[parent].defs |= d[cur].defs
d[parent].rdefs |= d[cur].rdefs
d[parent].extracount += d[cur].extracount
del d[cur]
heapq.heappush(work_q, (d[parent].priority(), parent))
dirty.add(parent)
def processMain(self, replace):
self.process(replace, False)
self._doGraphReplacements()
self.equality = calcEqualityData(self.graph)
self.process(replace, True)
###############################################################################
def mergeVariables(root, isstatic, parameters):
# first, create CFG from the Java AST
graph = makeGraph(root)
mergeinfo = VarMergeInfo(graph, parameters, isstatic)
replace = {}
mergeinfo.processMain(replace)
flattenDict(replace)
return replace

View File

@ -0,0 +1,55 @@
reserved_identifiers = '''
abstract
assert
boolean
break
byte
case
catch
char
class
const
continue
default
do
double
else
enum
extends
false
final
finally
float
for
goto
if
implements
import
instanceof
int
interface
long
native
new
null
package
private
protected
public
return
short
static
strictfp
super
switch
synchronized
this
throw
throws
transient
true
try
void
volatile
while
'''.split()

View File

@ -0,0 +1,63 @@
import itertools
def update(self, items):
self.entryBlock = items[0].entryBlock
self.nodes = frozenset.union(*(i.nodes for i in items))
temp = set(self.nodes)
siter = itertools.chain.from_iterable(i.successors for i in items)
self.successors = [n for n in siter if not n in temp and not temp.add(n)]
class SEBlockItem(object):
def __init__(self, node):
self.successors = node.norm_suc_nl # don't include backedges or exceptional edges
self.node = node
self.nodes = frozenset([node])
self.entryBlock = node
def getScopes(self): return ()
class SEScope(object):
def __init__(self, items):
self.items = items
update(self, items)
def getScopes(self): return ()
class SEWhile(object):
def __init__(self, scope):
self.body = scope
update(self, [scope])
def getScopes(self): return self.body,
class SETry(object):
def __init__(self, tryscope, catchscope, toptts, catchvar):
self.scopes = tryscope, catchscope
self.toptts = toptts
self.catchvar = catchvar # none if ignored
update(self, self.scopes)
def getScopes(self): return self.scopes
class SEIf(object):
def __init__(self, head, newscopes):
assert len(newscopes) == 2
self.scopes = newscopes
self.head = head
update(self, [head] + newscopes)
def getScopes(self): return self.scopes
class SESwitch(object):
def __init__(self, head, newscopes):
self.scopes = newscopes
self.head = head
self.ordered = newscopes
update(self, [head] + newscopes)
jump = head.node.block.jump
keysets = {head.node.blockdict[b.key,False]:jump.reverse.get(b) for b in jump.getNormalSuccessors()}
assert keysets.values().count(None) == 1
self.ordered_keysets = [keysets[item.entryBlock] for item in newscopes]
def getScopes(self): return self.scopes

View File

@ -0,0 +1,27 @@
# double quote, backslash, and newlines are forbidden
ok_chars = " !#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]^_`abcdefghijklmnopqrstuvwxyz{|}~"
ok_chars = frozenset(ok_chars)
# these characters cannot use unicode escape codes due to the way Java escaping works
late_escape = {u'\u0009':r'\t', u'\u000a':r'\n', u'\u000d':r'\r', u'\u0022':r'\"', u'\u005c':r'\\'}
def escapeString(u):
if set(u) <= ok_chars:
return u
escaped = []
for c in u:
if c in ok_chars:
escaped.append(c)
elif c in late_escape:
escaped.append(late_escape[c])
else:
i = ord(c)
if i <= 0xFFFF:
escaped.append(r'\u{0:04x}'.format(i))
else:
i -= 0x10000
high = 0xD800 + (i>>10)
low = 0xDC00 + (i & 0x3FF)
escaped.append(r'\u{0:04x}\u{1:04x}'.format(high,low))
return ''.join(escaped)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,48 @@
from ..ssa import objtypes
from . import ast
# A simple throws declaration inferrer that only considers throw statements within the method
# this is mostly just useful to make sure the ExceptionHandlers test compiles
def _visit_statement(env, stmt):
if isinstance(stmt, ast.ThrowStatement):
return stmt.expr.dtype
result = objtypes.NullTT
if isinstance(stmt, ast.TryStatement):
caught_types = []
for catch, b in stmt.pairs:
caught_types.extend(objtypes.className(tn.tt) for tn in catch.typename.tnames)
if objtypes.ThrowableTT not in caught_types:
temp = _visit_statement(env, stmt.tryb)
if temp != objtypes.NullTT:
assert objtypes.dim(temp) == 0
name = objtypes.className(temp)
if not any(env.isSubclass(name, caught) for caught in caught_types):
result = temp
statements = zip(*stmt.pairs)[1]
elif isinstance(stmt, ast.StatementBlock):
statements = stmt.statements
else:
statements = stmt.getScopes()
for sub in statements:
if result == objtypes.ThrowableTT:
break
result = objtypes.commonSupertype(env, [result, _visit_statement(env, sub)])
if result != objtypes.NullTT:
if env.isSubclass(objtypes.className(result), 'java/lang/RuntimeException'):
return objtypes.NullTT
return result
def addSingle(env, meth_asts):
for meth in meth_asts:
if not meth.body:
continue
tt = _visit_statement(env, meth.body)
assert objtypes.commonSupertype(env, [tt, objtypes.ThrowableTT]) == objtypes.ThrowableTT
if tt != objtypes.NullTT:
meth.throws = ast.TypeName(tt)

View File

@ -0,0 +1,14 @@
# Override this to rename classes
class DefaultVisitor(object):
def visit(self, obj):
return obj.print_(self, self.visit)
# Experimental - don't use!
def toTree(self, obj):
if obj is None:
return None
return obj.tree(self, self.toTree)
def className(self, name): return name
def methodName(self, cls, name, desc): return name
def fieldName(self, cls, name, desc): return name

View File

@ -0,0 +1,115 @@
import collections
from . import bytecode
from .attributes_raw import fixAttributeNames, get_attributes_raw
from .classfileformat.reader import Reader
exceptionHandlerRaw = collections.namedtuple("exceptionHandlerRaw",
["start","end","handler","type_ind"])
class Code(object):
def __init__(self, method, bytestream, keepRaw):
self.method = method
self.class_ = method.class_
# Old versions use shorter fields for stack, locals, and code length
field_fmt = ">HHL" if self.class_.version > (45,2) else ">BBH"
self.stack, self.locals, codelen = bytestream.get(field_fmt)
# assert codelen > 0 and codelen < 65536
self.bytecode_raw = bytestream.getRaw(codelen)
self.codelen = codelen
except_cnt = bytestream.get('>H')
self.except_raw = [bytestream.get('>HHHH') for _ in range(except_cnt)]
self.except_raw = [exceptionHandlerRaw(*t) for t in self.except_raw]
attributes_raw = get_attributes_raw(bytestream)
assert bytestream.size() == 0
if self.except_raw:
assert self.stack >= 1
# print 'Parsing code for', method.name, method.descriptor, method.flags
codestream = Reader(data=self.bytecode_raw)
self.bytecode = bytecode.parseInstructions(codestream, self.isIdConstructor)
self.attributes = fixAttributeNames(attributes_raw, self.class_.cpool)
for e in self.except_raw:
assert e.start in self.bytecode
assert e.end == codelen or e.end in self.bytecode
assert e.handler in self.bytecode
if keepRaw:
self.attributes_raw = attributes_raw
# This is a callback passed to the bytecode parser to determine if a given method id represents a constructor
def isIdConstructor(self, methId):
args = self.class_.cpool.getArgsCheck('Method', methId)
return args[1] == '<init>'
def __str__(self): # pragma: no cover
lines = ['Stack: {}, Locals {}'.format(self.stack, self.locals)]
instructions = self.bytecode
lines += ['{}: {}'.format(i, bytecode.printInstruction(instructions[i])) for i in sorted(instructions)]
if self.except_raw:
lines += ['Exception Handlers:']
lines += map(str, self.except_raw)
return '\n'.join(lines)
class Method(object):
flagVals = {'PUBLIC':0x0001,
'PRIVATE':0x0002,
'PROTECTED':0x0004,
'STATIC':0x0008,
'FINAL':0x0010,
'SYNCHRONIZED':0x0020,
'BRIDGE':0x0040,
'VARARGS':0x0080,
'NATIVE':0x0100,
'ABSTRACT':0x0400,
'STRICTFP':0x0800,
'SYNTHETIC':0x1000,
}
def __init__(self, data, classFile, keepRaw):
self.class_ = classFile
cpool = self.class_.cpool
flags, name_id, desc_id, attributes_raw = data
self.name = cpool.getArgsCheck('Utf8', name_id)
self.descriptor = cpool.getArgsCheck('Utf8', desc_id)
# print 'Loading method ', self.name, self.descriptor
self.attributes = fixAttributeNames(attributes_raw, cpool)
self.flags = set(name for name, mask in Method.flagVals.items() if (mask & flags))
# Flags are ignored for <clinit>?
if self.name == '<clinit>':
self.flags = set(['STATIC'])
self._checkFlags()
self.static = 'STATIC' in self.flags
self.native = 'NATIVE' in self.flags
self.abstract = 'ABSTRACT' in self.flags
self.isConstructor = (self.name == '<init>')
self.code = self._loadCode(keepRaw)
if keepRaw:
self.attributes_raw = attributes_raw
self.name_id, self.desc_id = name_id, desc_id
def _checkFlags(self):
assert len(self.flags & set(('PRIVATE','PROTECTED','PUBLIC'))) <= 1
if 'ABSTRACT' in self.flags:
assert not self.flags & set(['SYNCHRONIZED', 'PRIVATE', 'FINAL', 'STRICT', 'STATIC', 'NATIVE'])
def _loadCode(self, keepRaw):
code_attrs = [a for a in self.attributes if a[0] == 'Code']
if not (self.native or self.abstract):
assert len(code_attrs) == 1
code_raw = code_attrs[0][1]
bytestream = Reader(code_raw)
return Code(self, bytestream, keepRaw)
assert not code_attrs
return None

View File

@ -0,0 +1,18 @@
import collections
import itertools
class NameGen(object):
def __init__(self, reserved=frozenset()):
self.counters = collections.defaultdict(itertools.count)
self.names = set(reserved)
def getPrefix(self, prefix, sep=''):
newname = prefix
while newname in self.names:
newname = prefix + sep + str(next(self.counters[prefix]))
self.names.add(newname)
return newname
def LabelGen(prefix='label'):
for i in itertools.count():
yield prefix + str(i)

View File

@ -0,0 +1,65 @@
ADD = 'add'
AND = 'and'
ANEWARRAY = 'anewarray'
ARRLEN = 'arrlen'
ARRLOAD = 'arrload'
ARRLOAD_OBJ = 'arrload_obj'
ARRSTORE = 'arrstore'
ARRSTORE_OBJ = 'arrstore_obj'
CHECKCAST = 'checkcast'
CONST = 'const'
CONSTNULL = 'constnull'
CONVERT = 'convert'
DIV = 'div'
DUP = 'dup'
DUP2 = 'dup2'
DUP2X1 = 'dup2x1'
DUP2X2 = 'dup2x2'
DUPX1 = 'dupx1'
DUPX2 = 'dupx2'
FCMP = 'fcmp'
GETFIELD = 'getfield'
GETSTATIC = 'getstatic'
GOTO = 'goto'
IF_A = 'if_a'
IF_ACMP = 'if_acmp'
IF_I = 'if_i'
IF_ICMP = 'if_icmp'
IINC = 'iinc'
INSTANCEOF = 'instanceof'
INVOKEDYNAMIC = 'invokedynamic'
INVOKEINIT = 'invokeinit'
INVOKEINTERFACE = 'invokeinterface'
INVOKESPECIAL = 'invokespecial'
INVOKESTATIC = 'invokestatic'
INVOKEVIRTUAL = 'invokevirtual'
JSR = 'jsr'
LCMP = 'lcmp'
LDC = 'ldc'
LOAD = 'load'
MONENTER = 'monenter'
MONEXIT = 'monexit'
MUL = 'mul'
MULTINEWARRAY = 'multinewarray'
NEG = 'neg'
NEW = 'new'
NEWARRAY = 'newarray'
NOP = 'nop'
OR = 'or'
POP = 'pop'
POP2 = 'pop2'
PUTFIELD = 'putfield'
PUTSTATIC = 'putstatic'
REM = 'rem'
RET = 'ret'
RETURN = 'return'
SHL = 'shl'
SHR = 'shr'
STORE = 'store'
SUB = 'sub'
SWAP = 'swap'
SWITCH = 'switch'
THROW = 'throw'
TRUNCATE = 'truncate'
USHR = 'ushr'
XOR = 'xor'

View File

@ -0,0 +1,191 @@
from __future__ import print_function
import collections
import errno
from functools import partial
import hashlib
import os
import os.path
import platform
import zipfile
# Various utility functions for the top level scripts (decompile.py, assemble.py, disassemble.py)
copyright = '''Krakatau Copyright (C) 2012-18 Robert Grosse
This program is provided as open source under the GNU General Public License.
See LICENSE.TXT for more details.
'''
_osname = platform.system().lower()
IS_WINDOWS = 'win' in _osname and 'darwin' not in _osname and 'cygwin' not in _osname
def findFiles(target, recursive, prefix):
if target.endswith('.jar'):
with zipfile.ZipFile(target, 'r') as archive:
return [name.encode('utf8') for name in archive.namelist() if name.endswith(prefix)]
else:
if recursive:
assert os.path.isdir(target)
targets = []
for root, dirs, files in os.walk(target):
targets += [os.path.join(root, fname) for fname in files if fname.endswith(prefix)]
return targets
else:
return [target]
def normalizeClassname(name):
if name.endswith('.class'):
name = name[:-6]
# Replacing backslashes is ugly since they can be in valid classnames too, but this seems the best option
return name.replace('\\','/').replace('.','/')
# Windows stuff
illegal_win_chars = frozenset('<>;:|?*\\/"%')
pref_disp_chars = frozenset('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_$0123456789')
# Prevent creating filename parts matching the legacy device filenames. While Krakatau can create these files
# just fine thanks to using \\?\ paths, the resulting files are impossible to open or delete in Windows Explorer
# or with similar tools, so they are a huge pain to deal with. Therefore, we don't generate them at all.
illegal_parts = frozenset(['CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8',
'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9'])
class PathSanitizer(object):
def __init__(self, base, suffix):
self.base = base
self.suffix = suffix
def is_part_ok(self, s, parents):
if not 1 <= len(s) <= self.MAX_PART_LEN:
return False
# avoid potential collision with hashed parts
if len(s) >= 66 and '__' in s:
return False
# . cannot appear in a valid class name, but might as well specifically exclude these, just in case
if s.startswith('.') or '..' in s:
return False
return '\x1f' < min(s) <= max(s) < '\x7f'
def hash(self, s, suffix):
left = ''.join(c for c in s if c in pref_disp_chars)
right = '__' + hashlib.sha256(s.encode('utf8')).hexdigest() + suffix
return left[:self.MAX_PART_LEN - len(right)] + right
def sanitize(self, path):
if isinstance(path, bytes):
path = path.decode()
oldparts = path.split('/')
newparts = []
for i, part in enumerate(oldparts):
suffix = self.suffix if i + 1 == len(oldparts) else ''
if self.is_part_ok(part + suffix, newparts):
newparts.append(part + suffix)
else:
newparts.append(self.hash(part, suffix))
result = self.format_path([self.base] + newparts)
if len(result) > self.MAX_PATH_LEN:
result = self.format_path([self.base, self.hash(path, self.suffix)])
assert result.endswith(self.suffix)
return result
class LinuxPathSanitizer(PathSanitizer):
MAX_PART_LEN = 255
MAX_PATH_LEN = 4095
def __init__(self, *args):
PathSanitizer.__init__(self, *args)
def format_path(self, parts):
return os.path.join(*parts)
class WindowsPathSanitizer(PathSanitizer):
MAX_PART_LEN = 255
MAX_PATH_LEN = 32000 # close enough
def __init__(self, *args):
PathSanitizer.__init__(self, *args)
# keep track of previous paths to detect case-insensitive collisions
self.prevs = collections.defaultdict(dict)
def is_part_ok(self, s, parents):
if not PathSanitizer.is_part_ok(self, s, parents):
return False
if s.upper() in illegal_parts:
return False
# make sure nothing in the current directory is a case insensitive collision
if self.prevs[tuple(parents)].setdefault(s.lower(), s) != s:
return False
return illegal_win_chars.isdisjoint(s)
def format_path(self, parts):
return '\\\\?\\' + '\\'.join(parts)
class DirectoryWriter(object):
def __init__(self, base_path, suffix):
if base_path is None:
base_path = os.getcwd()
else:
if not isinstance(base_path, str):
base_path = base_path.decode('utf8')
base_path = os.path.abspath(base_path)
if IS_WINDOWS:
self.makepath = WindowsPathSanitizer(base_path, suffix).sanitize
else:
self.makepath = LinuxPathSanitizer(base_path, suffix).sanitize
def write(self, cname, data):
out = self.makepath(cname)
dirpath = os.path.dirname(out)
try:
if dirpath:
os.makedirs(dirpath)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
mode = 'wb' if isinstance(data, bytes) else 'w'
with open(out, mode) as f:
f.write(data)
return out
def __enter__(self): return self
def __exit__(self, *args): pass
class JarWriter(object):
def __init__(self, base_path, suffix):
self.zip = zipfile.ZipFile(base_path, mode='w')
self.suffix = suffix
def write(self, cname, data):
info = zipfile.ZipInfo(cname + self.suffix, (1980, 1, 1, 0, 0, 0))
self.zip.writestr(info, data)
return 'zipfile'
def __enter__(self): self.zip.__enter__(); return self
def __exit__(self, *args): self.zip.__exit__(*args)
class MockWriter(object):
def __init__(self): self.results = []
def write(self, cname, data): self.results.append((cname, data))
def __enter__(self): return self
def __exit__(self, *args): pass
def makeWriter(base_path, suffix):
if base_path is not None:
if base_path.endswith('.zip') or base_path.endswith('.jar'):
return JarWriter(base_path, suffix)
return DirectoryWriter(base_path, suffix)
###############################################################################
def ignore(*args, **kwargs):
pass
class Logger(object):
def __init__(self, level):
lvl = ['info', 'warning'].index(level)
self.info = print if lvl <= 0 else ignore
self.warn = print if lvl <= 1 else ignore

View File

@ -0,0 +1 @@
from .graph import ssaFromVerified

View File

@ -0,0 +1,360 @@
import collections
import operator
from .. import opnames as vops
from ..verifier import verifier_types
from . import ssa_jumps, ssa_ops, subproc
from .blockmakerfuncs import ResultDict, instructionHandlers
from .ssa_types import BasicBlock, SSA_OBJECT, slots_t
def toBits(x): return [i for i in range(x.bit_length()) if x & (1 << i)]
# keys for special blocks created at the cfg entry and exit. Negative keys ensures they don't collide
ENTRY_KEY, RETURN_KEY, RETHROW_KEY = -1, -2, -3
def getUsedLocals(iNodes, iNodeD, exceptions):
# For every instruction, find which locals at that point may be used in the future
except_ranges = [(h, [node.key for node in iNodes if s <= node.key < e]) for s, e, h, i in exceptions]
old = collections.defaultdict(int)
while 1:
data = old.copy()
# Do one iteration
for node in reversed(iNodes):
used = reduce(operator.__or__, (data[key] for key in node.successors), 0)
if node.instruction[0] == vops.LOAD:
used |= 1 << node.instruction[2]
elif node.instruction[0] == vops.IINC:
used |= 1 << node.instruction[1]
elif node.instruction[0] == vops.STORE:
bits = 3 if node.instruction[1] in 'JD' else 1
used &= ~(bits << node.instruction[2])
elif node.instruction[0] == vops.RET:
# If local is not in mask, it will use the value from the jsr instead of the ret
mask = sum(1<<i for i in node.out_state.copy().maskFor(node.jsrTarget))
used &= mask
elif node.instruction[0] == vops.JSR and node.returnedFrom is not None:
retnode = iNodeD[node.returnedFrom]
assert node.successors == (retnode.jsrTarget,)
mask = sum(1<<i for i in iNodeD[node.returnedFrom].out_state.copy().maskFor(retnode.jsrTarget))
assert node.next_instruction is not None
used |= (data[node.next_instruction] & ~mask)
data[node.key] |= used
for hkey, region in except_ranges:
if data[hkey] != old[hkey]:
for key in region:
data[key] |= data[hkey]
if data == old:
break
old = data
# for entry point, every program argument is marked used so we can preserve input arguments for later
old[ENTRY_KEY] = (1 << len(iNodeD[0].state.locals)) - 1
return old
def slotsRvals(inslots):
stack = [(None if phi is None else phi.rval) for phi in inslots.stack]
newlocals = {i: phi.rval for i, phi in inslots.locals.items() if phi is not None}
return slots_t(stack=stack, locals=newlocals)
_jump_instrs = frozenset([vops.GOTO, vops.IF_A, vops.IF_ACMP, vops.IF_I, vops.IF_ICMP, vops.JSR, vops.SWITCH])
class BlockMaker(object):
def __init__(self, parent, iNodes, inputTypes, returnTypes, except_raw, opts):
self.parent = parent
self.blocks = []
self.blockd = {}
self.iNodes = [n for n in iNodes if n.visited]
self.iNodeD = {n.key: n for n in self.iNodes}
exceptions = [eh for eh in except_raw if eh.handler in self.iNodeD]
# Calculate which locals are actually live at any point
self.used_locals = getUsedLocals(self.iNodes, self.iNodeD, exceptions)
# create map of uninitialized -> initialized types so we can convert them
self.initMap = {}
for node in self.iNodes:
if node.op == vops.NEW:
self.initMap[node.stack_push[0]] = node.target_type
self.initMap[verifier_types.T_UNINIT_THIS] = verifier_types.T_OBJECT(parent.class_.name)
self.hasmonenter = any(node.instruction[0] == vops.MONENTER for node in self.iNodes)
self.entryBlock = self.makeBlockWithInslots(ENTRY_KEY, newlocals=inputTypes, stack=[])
self.returnBlock = self.makeBlockWithInslots(RETURN_KEY, newlocals=[], stack=returnTypes)
self.returnBlock.jump = ssa_jumps.Return(self, [phi.rval for phi in self.returnBlock.phis])
self.rethrowBlock = self.makeBlockWithInslots(RETHROW_KEY, newlocals=[], stack=[verifier_types.THROWABLE_INFO])
self.rethrowBlock.jump = ssa_jumps.Rethrow(self, [phi.rval for phi in self.rethrowBlock.phis])
# for ssagraph to copy
self.inputArgs = slotsRvals(self.entryBlock.inslots).localsAsList
self.entryBlock.phis = []
# We need to create stub blocks for every jump target so we can add them as successors during creation
jump_targets = [eh.handler for eh in exceptions]
for node in self.iNodes:
if node.instruction[0] in _jump_instrs:
jump_targets += node.successors
# add jsr fallthroughs too
if node.instruction[0] == vops.JSR and node.returnedFrom is not None:
jump_targets.append(node.next_instruction)
# for simplicity, keep jsr stuff in individual instruction blocks.
# Note that subproc.py will need to be modified if this is changed
for node in self.iNodes:
if node.instruction[0] in (vops.JSR, vops.RET):
jump_targets.append(node.key)
for key in jump_targets:
if key not in self.blockd: # jump_targets may have duplicates
self.makeBlock(key)
self.exceptionhandlers = []
for (start, end, handler, index) in exceptions:
catchtype = parent.getConstPoolArgs(index)[0] if index else 'java/lang/Throwable'
self.exceptionhandlers.append((start, end, self.blockd[handler], catchtype))
self.exceptionhandlers.append((0, 65536, self.rethrowBlock, 'java/lang/Throwable'))
# State variables for the append/builder loop
self.current_block = self.entryBlock
self.current_slots = slotsRvals(self.current_block.inslots)
for node in self.iNodes:
# First do a quick check if we have to start a new block
if not self._canContinueBlock(node):
self._startNewBlock(node.key)
vals, outslot_norm = self._getInstrLine(node)
# Disable exception pruning
if opts and not vals.jump:
dummyvals = ResultDict(line=ssa_ops.MagicThrow(self.parent))
if not self._canAppendInstrToCurrent(node.key, dummyvals):
self._startNewBlock(node.key)
assert self._canAppendInstrToCurrent(node.key, dummyvals)
self._appendInstr(node, dummyvals, self.current_slots, check_terminate=False)
vals, outslot_norm = self._getInstrLine(node)
if not self._canAppendInstrToCurrent(node.key, vals):
self._startNewBlock(node.key)
vals, outslot_norm = self._getInstrLine(node)
assert self._canAppendInstrToCurrent(node.key, vals)
self._appendInstr(node, vals, outslot_norm)
# do sanity checks
assert len(self.blocks) == len(self.blockd)
for block in self.blocks:
assert block.jump is not None and block.phis is not None
assert len(block.predecessors) == len(set(block.predecessors))
# cleanup temp vars
block.inslots = None
block.throwvars = None
block.chpairs = None
block.except_used = None
block.locals_at_except = None
def _canContinueBlock(self, node):
return (node.key not in self.blockd) and self.current_block.jump is None # fallthrough goto left as None
def _chPairsAt(self, address):
chpairs = []
for (start, end, handler, catchtype) in self.exceptionhandlers:
if start <= address < end:
chpairs.append((catchtype, handler))
return chpairs
def _canAppendInstrToCurrent(self, address, vals):
# If appending exception line to block with existing exceptions, make sure the handlers are the same
# Also make sure that locals are compatible with all other exceptions in the block
# If appending a jump, make sure there is no existing exceptions
block = self.current_block
if block.chpairs is not None:
if vals.jump:
return False
if vals.line is not None and vals.line.outException is not None:
chpairs = self._chPairsAt(address)
if chpairs != block.chpairs:
return False
newlocals = {i: self.current_slots.locals[i] for i in block.except_used}
return newlocals == block.locals_at_except
assert block.jump is None
return True
def pruneUnused(self, key, newlocals):
used = toBits(self.used_locals[key])
return {i: newlocals[i] for i in used}
def _startNewBlock(self, key):
''' We can't continue appending to the current block, so start a new one (or use existing one at location) '''
# Make new block
if key not in self.blockd:
self.makeBlock(key)
# Finish current block
block = self.current_block
curslots = self.current_slots
assert block.key != key
if block.jump is None:
if block.chpairs is not None:
assert block.throwvars
self._addOnException(block, self.blockd[key], curslots)
else:
assert not block.throwvars
block.jump = ssa_jumps.Goto(self.parent, self.blockd[key])
if curslots is not None:
self.mergeIn((block, False), key, curslots)
# Update state
self.current_block = self.blockd[key]
self.current_slots = slotsRvals(self.current_block.inslots)
def _getInstrLine(self, iNode):
parent, initMap = self.parent, self.initMap
inslots = self.current_slots
instr = iNode.instruction
# internal variables won't have any preset type info associated, so we should add in the info from the verifier
assert len(inslots.stack) == len(iNode.state.stack)
for i, ivar in enumerate(inslots.stack):
if ivar and ivar.type == SSA_OBJECT and ivar.decltype is None:
parent.setObjVarData(ivar, iNode.state.stack[i], initMap)
for i, ivar in inslots.locals.items():
if ivar and ivar.type == SSA_OBJECT and ivar.decltype is None:
parent.setObjVarData(ivar, iNode.state.locals[i], initMap)
vals = instructionHandlers[instr[0]](self, inslots, iNode)
newstack = vals.newstack if vals.newstack is not None else inslots.stack
newlocals = vals.newlocals if vals.newlocals is not None else inslots.locals
outslot_norm = slots_t(locals=newlocals, stack=newstack)
return vals, outslot_norm
def _addOnException(self, block, fallthrough, outslot_norm):
parent = self.parent
assert block.throwvars and block.chpairs is not None
ephi = ssa_ops.ExceptionPhi(parent, block.throwvars)
block.lines.append(ephi)
assert block.jump is None
block.jump = ssa_jumps.OnException(parent, ephi.outException, block.chpairs, fallthrough)
outslot_except = slots_t(locals=block.locals_at_except, stack=[ephi.outException])
for suc in block.jump.getExceptSuccessors():
self.mergeIn((block, True), suc.key, outslot_except)
def _appendInstr(self, iNode, vals, outslot_norm, check_terminate=True):
parent = self.parent
block = self.current_block
line, jump = vals.line, vals.jump
if line is not None:
block.lines.append(line)
assert block.jump is None
block.jump = jump
if line is not None and line.outException is not None:
block.throwvars.append(line.outException)
inslots = self.current_slots
if block.chpairs is None:
block.chpairs = self._chPairsAt(iNode.key)
temp = (self.used_locals[h.key] for t, h in block.chpairs)
block.except_used = toBits(reduce(operator.__or__, temp, 0))
block.locals_at_except = {i: inslots.locals[i] for i in block.except_used}
if check_terminate:
# Return and Throw must be immediately ended because they don't have normal fallthrough
# CheckCast must terminate block because cast type hack later on requires casts to be at end of block
if iNode.instruction[0] in (vops.RETURN, vops.THROW) or isinstance(line, ssa_ops.CheckCast):
fallthrough = self.getExceptFallthrough(iNode)
self._addOnException(block, fallthrough, outslot_norm)
if block.jump is None:
unmerged_slots = outslot_norm
else:
assert isinstance(block.jump, ssa_jumps.OnException) or not block.throwvars
unmerged_slots = None
# Make sure that branch targets are distinct, since this is assumed everywhere
# Only necessary for if statements as the other jumps merge targets automatically
# If statements with both branches jumping to same target are replaced with gotos
block.jump = block.jump.reduceSuccessors([])
if isinstance(block.jump, subproc.ProcCallOp):
self.mergeJSROut(iNode, block, outslot_norm)
else:
for suc in block.jump.getNormalSuccessors():
self.mergeIn((block, False), suc.key, outslot_norm)
self.current_slots = unmerged_slots
assert (block.chpairs is None) == (block.except_used is None) == (block.locals_at_except is None)
def mergeIn(self, from_key, target_key, outslots):
inslots = self.blockd[target_key].inslots
assert len(inslots.stack) == len(outslots.stack)
for i, phi in enumerate(inslots.stack):
if phi is not None:
phi.add(from_key, outslots.stack[i])
for i, phi in inslots.locals.items():
if phi is not None:
phi.add(from_key, outslots.locals[i])
self.blockd[target_key].predecessors.append(from_key)
## Block Creation #########################################
def _makePhiFromVType(self, block, vt):
var = self.parent.makeVarFromVtype(vt, self.initMap)
return None if var is None else ssa_ops.Phi(block, var)
def makeBlockWithInslots(self, key, newlocals, stack):
assert key not in self.blockd
block = BasicBlock(key)
self.blocks.append(block)
self.blockd[key] = block
# create inslot phis
stack = [self._makePhiFromVType(block, vt) for vt in stack]
newlocals = dict(enumerate(self._makePhiFromVType(block, vt) for vt in newlocals))
newlocals = self.pruneUnused(key, newlocals)
block.inslots = slots_t(locals=newlocals, stack=stack)
block.phis = [phi for phi in stack + block.inslots.localsAsList if phi is not None]
return block
def makeBlock(self, key):
node = self.iNodeD[key]
return self.makeBlockWithInslots(key, node.state.locals, node.state.stack)
###########################################################
def getExceptFallthrough(self, iNode):
vop = iNode.instruction[0]
if vop == vops.RETURN:
return self.blockd[RETURN_KEY]
elif vop == vops.THROW:
return None
key = iNode.successors[0]
if key not in self.blockd:
self.makeBlock(key)
return self.blockd[key]
def mergeJSROut(self, jsrnode, block, outslot_norm):
retnode = self.iNodeD[jsrnode.returnedFrom]
jump = block.jump
target_key, ft_key = jump.target.key, jump.fallthrough.key
assert ft_key == jsrnode.next_instruction
# first merge regular jump to target
self.mergeIn((block, False), target_key, outslot_norm)
# create merged outslots for fallthrough
fromcall = jump.output
mask = [mask for key, mask in retnode.state.masks if key == target_key][0]
skiplocs = fromcall.locals
retlocs = outslot_norm.locals
merged = {i: (skiplocs.get(i) if i in mask else retlocs.get(i)) for i in (mask | frozenset(retlocs))}
# jump.debug_skipvars = set(merged) - set(locals)
outslot_merged = slots_t(locals=merged, stack=fromcall.stack)
# merge merged outputs with fallthrough
self.mergeIn((block, False), ft_key, outslot_merged)

View File

@ -0,0 +1,461 @@
from .. import opnames as vops
from ..verifier.descriptors import parseFieldDescriptor, parseMethodDescriptor
from . import objtypes, ssa_jumps, ssa_ops, subproc
from .ssa_types import SSA_DOUBLE, SSA_FLOAT, SSA_INT, SSA_LONG, SSA_OBJECT, slots_t
_charToSSAType = {'D':SSA_DOUBLE, 'F':SSA_FLOAT, 'I':SSA_INT, 'J':SSA_LONG,
'B':SSA_INT, 'C':SSA_INT, 'S':SSA_INT}
def getCategory(c): return 2 if c in 'JD' else 1
class ResultDict(object):
def __init__(self, line=None, jump=None, newstack=None, newlocals=None):
self.line = line
self.jump = jump
self.newstack = newstack
self.newlocals = newlocals
##############################################################################
def makeConstVar(parent, type_, val):
var = parent.makeVariable(type_)
var.const = val
return var
def parseArrOrClassName(desc):
# Accept either a class or array descriptor or a raw class name.
if desc.startswith('[') or desc.endswith(';'):
vtypes = parseFieldDescriptor(desc, unsynthesize=False)
tt = objtypes.verifierToSynthetic(vtypes[0])
else:
tt = objtypes.TypeTT(desc, 0)
return tt
def _floatOrIntMath(fop, iop):
def math1(maker, input_, iNode):
cat = getCategory(iNode.instruction[1])
isfloat = (iNode.instruction[1] in 'DF')
op = fop if isfloat else iop
args = input_.stack[-cat*2::cat]
line = op(maker.parent, args)
newstack = input_.stack[:-2*cat] + [line.rval] + [None]*(cat-1)
return ResultDict(line=line, newstack=newstack)
return math1
def _intMath(op, isShift):
def math2(maker, input_, iNode):
cat = getCategory(iNode.instruction[1])
# some ops (i.e. shifts) always take int as second argument
size = cat+1 if isShift else cat+cat
args = input_.stack[-size::cat]
line = op(maker.parent, args)
newstack = input_.stack[:-size] + [line.rval] + [None]*(cat-1)
return ResultDict(line=line, newstack=newstack)
return math2
##############################################################################
def _anewarray(maker, input_, iNode):
name = maker.parent.getConstPoolArgs(iNode.instruction[1])[0]
tt = parseArrOrClassName(name)
line = ssa_ops.NewArray(maker.parent, input_.stack[-1], tt)
newstack = input_.stack[:-1] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _arrlen(maker, input_, iNode):
line = ssa_ops.ArrLength(maker.parent, input_.stack[-1:])
newstack = input_.stack[:-1] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _arrload(maker, input_, iNode):
type_ = _charToSSAType[iNode.instruction[1]]
cat = getCategory(iNode.instruction[1])
line = ssa_ops.ArrLoad(maker.parent, input_.stack[-2:], type_)
newstack = input_.stack[:-2] + [line.rval] + [None]*(cat-1)
return ResultDict(line=line, newstack=newstack)
def _arrload_obj(maker, input_, iNode):
line = ssa_ops.ArrLoad(maker.parent, input_.stack[-2:], SSA_OBJECT)
newstack = input_.stack[:-2] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _arrstore(maker, input_, iNode):
if getCategory(iNode.instruction[1]) > 1:
newstack, args = input_.stack[:-4], input_.stack[-4:-1]
arr_vt, ind_vt = iNode.state.stack[-4:-2]
else:
newstack, args = input_.stack[:-3], input_.stack[-3:]
arr_vt, ind_vt = iNode.state.stack[-3:-1]
line = ssa_ops.ArrStore(maker.parent, args)
# Check if we can prune the exception early because the
# array size and index are known constants
if arr_vt.const is not None and ind_vt.const is not None:
if 0 <= ind_vt.const < arr_vt.const:
line.outException = None
return ResultDict(line=line, newstack=newstack)
def _arrstore_obj(maker, input_, iNode):
line = ssa_ops.ArrStore(maker.parent, input_.stack[-3:])
newstack = input_.stack[:-3]
return ResultDict(line=line, newstack=newstack)
def _checkcast(maker, input_, iNode):
index = iNode.instruction[1]
desc = maker.parent.getConstPoolArgs(index)[0]
tt = parseArrOrClassName(desc)
line = ssa_ops.CheckCast(maker.parent, tt, input_.stack[-1:])
return ResultDict(line=line)
def _const(maker, input_, iNode):
ctype, val = iNode.instruction[1:]
cat = getCategory(ctype)
type_ = _charToSSAType[ctype]
var = makeConstVar(maker.parent, type_, val)
newstack = input_.stack + [var] + [None]*(cat-1)
return ResultDict(newstack=newstack)
def _constnull(maker, input_, iNode):
var = makeConstVar(maker.parent, SSA_OBJECT, 'null')
var.decltype = objtypes.NullTT
newstack = input_.stack + [var]
return ResultDict(newstack=newstack)
def _convert(maker, input_, iNode):
src_c, dest_c = iNode.instruction[1:]
src_cat, dest_cat = getCategory(src_c), getCategory(dest_c)
stack, arg = input_.stack[:-src_cat], input_.stack[-src_cat]
line = ssa_ops.Convert(maker.parent, arg, _charToSSAType[src_c], _charToSSAType[dest_c])
newstack = stack + [line.rval] + [None]*(dest_cat-1)
return ResultDict(line=line, newstack=newstack)
def _fcmp(maker, input_, iNode):
op, c, NaN_val = iNode.instruction
cat = getCategory(c)
args = input_.stack[-cat*2::cat]
line = ssa_ops.FCmp(maker.parent, args, NaN_val)
newstack = input_.stack[:-cat*2] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _field_access(maker, input_, iNode):
index = iNode.instruction[1]
target, name, desc = maker.parent.getConstPoolArgs(index)
cat = len(parseFieldDescriptor(desc))
argcnt = cat if 'put' in iNode.instruction[0] else 0
if not 'static' in iNode.instruction[0]:
argcnt += 1
splitInd = len(input_.stack) - argcnt
args = [x for x in input_.stack[splitInd:] if x is not None]
line = ssa_ops.FieldAccess(maker.parent, iNode.instruction, (target, name, desc), args=args)
newstack = input_.stack[:splitInd] + line.returned
return ResultDict(line=line, newstack=newstack)
def _goto(maker, input_, iNode):
jump = ssa_jumps.Goto(maker.parent, maker.blockd[iNode.successors[0]])
return ResultDict(jump=jump)
def _if_a(maker, input_, iNode):
null = makeConstVar(maker.parent, SSA_OBJECT, 'null')
null.decltype = objtypes.NullTT
jump = ssa_jumps.If(maker.parent, iNode.instruction[1], map(maker.blockd.get, iNode.successors), (input_.stack[-1], null))
newstack = input_.stack[:-1]
return ResultDict(jump=jump, newstack=newstack)
def _if_i(maker, input_, iNode):
zero = makeConstVar(maker.parent, SSA_INT, 0)
jump = ssa_jumps.If(maker.parent, iNode.instruction[1], map(maker.blockd.get, iNode.successors), (input_.stack[-1], zero))
newstack = input_.stack[:-1]
return ResultDict(jump=jump, newstack=newstack)
def _if_cmp(maker, input_, iNode):
jump = ssa_jumps.If(maker.parent, iNode.instruction[1], map(maker.blockd.get, iNode.successors), input_.stack[-2:])
newstack = input_.stack[:-2]
return ResultDict(jump=jump, newstack=newstack)
def _iinc(maker, input_, iNode):
_, index, amount = iNode.instruction
oldval = input_.locals[index]
constval = makeConstVar(maker.parent, SSA_INT, amount)
line = ssa_ops.IAdd(maker.parent, (oldval, constval))
newlocals = input_.locals.copy()
newlocals[index] = line.rval
return ResultDict(line=line, newlocals=newlocals)
def _instanceof(maker, input_, iNode):
index = iNode.instruction[1]
desc = maker.parent.getConstPoolArgs(index)[0]
tt = parseArrOrClassName(desc)
line = ssa_ops.InstanceOf(maker.parent, tt, input_.stack[-1:])
newstack = input_.stack[:-1] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _invoke(maker, input_, iNode):
index = iNode.instruction[1]
target, name, desc = maker.parent.getConstPoolArgs(index)
target_tt = parseArrOrClassName(target)
argcnt = len(parseMethodDescriptor(desc)[0])
if not 'static' in iNode.instruction[0]:
argcnt += 1
splitInd = len(input_.stack) - argcnt
# If we are an initializer, store a copy of the uninitialized verifier type so the Java decompiler can patch things up later
isThisCtor = iNode.isThisCtor if iNode.op == vops.INVOKEINIT else False
args = [x for x in input_.stack[splitInd:] if x is not None]
line = ssa_ops.Invoke(maker.parent, iNode.instruction, (target, name, desc),
args=args, isThisCtor=isThisCtor, target_tt=target_tt)
newstack = input_.stack[:splitInd] + line.returned
return ResultDict(line=line, newstack=newstack)
def _invoke_dynamic(maker, input_, iNode):
index = iNode.instruction[1]
desc = maker.parent.getConstPoolArgs(index)[2]
argcnt = len(parseMethodDescriptor(desc)[0])
splitInd = len(input_.stack) - argcnt
args = [x for x in input_.stack[splitInd:] if x is not None]
line = ssa_ops.InvokeDynamic(maker.parent, desc, args)
newstack = input_.stack[:splitInd] + line.returned
return ResultDict(line=line, newstack=newstack)
def _jsr(maker, input_, iNode):
newstack = input_.stack + [None]
if iNode.returnedFrom is None:
jump = ssa_jumps.Goto(maker.parent, maker.blockd[iNode.successors[0]])
return ResultDict(newstack=newstack, jump=jump)
# create output variables from callop to represent vars received from ret.
# We can use {} for initMap since there will never be unintialized types here
retnode = maker.iNodeD[iNode.returnedFrom]
stack = [maker.parent.makeVarFromVtype(vt, {}) for vt in retnode.out_state.stack]
newlocals = dict(enumerate(maker.parent.makeVarFromVtype(vt, {}) for vt in retnode.out_state.locals))
newlocals = maker.pruneUnused(retnode.key, newlocals)
out_slots = slots_t(locals=newlocals, stack=stack)
# Simply store the data for now and fix things up once all the blocks are created
jump = subproc.ProcCallOp(maker.blockd[iNode.successors[0]], maker.blockd[iNode.next_instruction], input_, out_slots)
return ResultDict(jump=jump, newstack=newstack)
def _lcmp(maker, input_, iNode):
args = input_.stack[-4::2]
line = ssa_ops.ICmp(maker.parent, args)
newstack = input_.stack[:-4] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _ldc(maker, input_, iNode):
index, cat = iNode.instruction[1:]
entry_type = maker.parent.getConstPoolType(index)
args = maker.parent.getConstPoolArgs(index)
var = None
if entry_type == 'String':
var = makeConstVar(maker.parent, SSA_OBJECT, args[0])
var.decltype = objtypes.StringTT
elif entry_type == 'Int':
var = makeConstVar(maker.parent, SSA_INT, args[0])
elif entry_type == 'Long':
var = makeConstVar(maker.parent, SSA_LONG, args[0])
elif entry_type == 'Float':
var = makeConstVar(maker.parent, SSA_FLOAT, args[0])
elif entry_type == 'Double':
var = makeConstVar(maker.parent, SSA_DOUBLE, args[0])
elif entry_type == 'Class':
var = makeConstVar(maker.parent, SSA_OBJECT, parseArrOrClassName(args[0]))
var.decltype = objtypes.ClassTT
# Todo - handle MethodTypes and MethodHandles?
assert var
newstack = input_.stack + [var] + [None]*(cat-1)
return ResultDict(newstack=newstack)
def _load(maker, input_, iNode):
cat = getCategory(iNode.instruction[1])
index = iNode.instruction[2]
newstack = input_.stack + [input_.locals[index]] + [None]*(cat-1)
return ResultDict(newstack=newstack)
def _monitor(maker, input_, iNode):
isExit = 'exit' in iNode.instruction[0]
line = ssa_ops.Monitor(maker.parent, input_.stack[-1:], isExit)
newstack = input_.stack[:-1]
return ResultDict(line=line, newstack=newstack)
def _multinewarray(maker, input_, iNode):
op, index, dim = iNode.instruction
name = maker.parent.getConstPoolArgs(index)[0]
tt = parseArrOrClassName(name)
assert objtypes.dim(tt) >= dim
line = ssa_ops.MultiNewArray(maker.parent, input_.stack[-dim:], tt)
newstack = input_.stack[:-dim] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _neg(maker, input_, iNode):
cat = getCategory(iNode.instruction[1])
arg = input_.stack[-cat:][0]
if (iNode.instruction[1] in 'DF'):
line = ssa_ops.FNeg(maker.parent, [arg])
else: # for integers, we can just write -x as 0 - x
zero = makeConstVar(maker.parent, arg.type, 0)
line = ssa_ops.ISub(maker.parent, [zero,arg])
newstack = input_.stack[:-cat] + [line.rval] + [None]*(cat-1)
return ResultDict(line=line, newstack=newstack)
def _new(maker, input_, iNode):
index = iNode.instruction[1]
classname = maker.parent.getConstPoolArgs(index)[0]
if classname.endswith(';'):
classname = classname[1:-1]
line = ssa_ops.New(maker.parent, classname, iNode.key)
newstack = input_.stack + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _newarray(maker, input_, iNode):
vtypes = parseFieldDescriptor(iNode.instruction[1], unsynthesize=False)
tt = objtypes.verifierToSynthetic(vtypes[0])
line = ssa_ops.NewArray(maker.parent, input_.stack[-1], tt)
newstack = input_.stack[:-1] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def _nop(maker, input_, iNode):
return ResultDict()
def _ret(maker, input_, iNode):
jump = subproc.DummyRet(input_, maker.blockd[iNode.jsrTarget])
return ResultDict(jump=jump)
def _return(maker, input_, iNode):
# Our special return block expects only the return values on the stack
rtype = iNode.instruction[1]
if rtype is None:
newstack = []
else:
newstack = input_.stack[-getCategory(rtype):]
# TODO: enable once structuring is smarter
# if not maker.hasmonenter:
# jump = ssa_jumps.Goto(maker.parent, maker.returnBlock)
# return ResultDict(jump=jump, newstack=newstack)
line = ssa_ops.TryReturn(maker.parent)
return ResultDict(line=line, newstack=newstack)
def _store(maker, input_, iNode):
cat = getCategory(iNode.instruction[1])
index = iNode.instruction[2]
newlocals = input_.locals.copy()
newlocals[index] = input_.stack[-cat]
newstack = input_.stack[:-cat]
return ResultDict(newstack=newstack, newlocals=newlocals)
def _switch(maker, input_, iNode):
default, raw_table = iNode.instruction[1:3]
table = [(k, maker.blockd[v]) for k,v in raw_table]
jump = ssa_jumps.Switch(maker.parent, maker.blockd[default], table, input_.stack[-1:])
newstack = input_.stack[:-1]
return ResultDict(jump=jump, newstack=newstack)
def _throw(maker, input_, iNode):
line = ssa_ops.Throw(maker.parent, input_.stack[-1:])
return ResultDict(line=line, newstack=[])
def _truncate(maker, input_, iNode):
dest_c = iNode.instruction[1]
signed, width = {'B':(True, 8), 'C':(False, 16), 'S':(True, 16)}[dest_c]
line = ssa_ops.Truncate(maker.parent, input_.stack[-1], signed=signed, width=width)
newstack = input_.stack[:-1] + [line.rval]
return ResultDict(line=line, newstack=newstack)
def genericStackUpdate(maker, input_, iNode):
n = iNode.pop_amount
stack = input_.stack
stack, popped = stack[:-n], stack[-n:]
for i in iNode.stack_code:
stack.append(popped[i])
return ResultDict(newstack=stack)
instructionHandlers = {
vops.ADD: _floatOrIntMath(ssa_ops.FAdd, ssa_ops.IAdd),
vops.AND: _intMath(ssa_ops.IAnd, isShift=False),
vops.ANEWARRAY: _anewarray,
vops.ARRLEN: _arrlen,
vops.ARRLOAD: _arrload,
vops.ARRLOAD_OBJ: _arrload_obj,
vops.ARRSTORE: _arrstore,
vops.ARRSTORE_OBJ: _arrstore_obj,
vops.CHECKCAST: _checkcast,
vops.CONST: _const,
vops.CONSTNULL: _constnull,
vops.CONVERT: _convert,
vops.DIV: _floatOrIntMath(ssa_ops.FDiv, ssa_ops.IDiv),
vops.FCMP: _fcmp,
vops.GETSTATIC: _field_access,
vops.GETFIELD: _field_access,
vops.GOTO: _goto,
vops.IF_A: _if_a,
vops.IF_ACMP: _if_cmp, # cmp works on ints or objs
vops.IF_I: _if_i,
vops.IF_ICMP: _if_cmp,
vops.IINC: _iinc,
vops.INSTANCEOF: _instanceof,
vops.INVOKEINIT: _invoke,
vops.INVOKEINTERFACE: _invoke,
vops.INVOKESPECIAL: _invoke,
vops.INVOKESTATIC: _invoke,
vops.INVOKEVIRTUAL: _invoke,
vops.INVOKEDYNAMIC: _invoke_dynamic,
vops.JSR: _jsr,
vops.LCMP: _lcmp,
vops.LDC: _ldc,
vops.LOAD: _load,
vops.MONENTER: _monitor,
vops.MONEXIT: _monitor,
vops.MULTINEWARRAY: _multinewarray,
vops.MUL: _floatOrIntMath(ssa_ops.FMul, ssa_ops.IMul),
vops.NEG: _neg,
vops.NEW: _new,
vops.NEWARRAY: _newarray,
vops.NOP: _nop,
vops.OR: _intMath(ssa_ops.IOr, isShift=False),
vops.PUTSTATIC: _field_access,
vops.PUTFIELD: _field_access,
vops.REM: _floatOrIntMath(ssa_ops.FRem, ssa_ops.IRem),
vops.RET: _ret,
vops.RETURN: _return,
vops.SHL: _intMath(ssa_ops.IShl, isShift=True),
vops.SHR: _intMath(ssa_ops.IShr, isShift=True),
vops.STORE: _store,
vops.SUB: _floatOrIntMath(ssa_ops.FSub, ssa_ops.ISub),
vops.SWITCH: _switch,
vops.THROW: _throw,
vops.TRUNCATE: _truncate,
vops.USHR: _intMath(ssa_ops.IUshr, isShift=True),
vops.XOR: _intMath(ssa_ops.IXor, isShift=False),
vops.SWAP: genericStackUpdate,
vops.POP: genericStackUpdate,
vops.POP2: genericStackUpdate,
vops.DUP: genericStackUpdate,
vops.DUPX1: genericStackUpdate,
vops.DUPX2: genericStackUpdate,
vops.DUP2: genericStackUpdate,
vops.DUP2X1: genericStackUpdate,
vops.DUP2X2: genericStackUpdate,
}

View File

@ -0,0 +1,66 @@
import collections, itertools
from ... import floatutil
from .. import objtypes
from .int_c import IntConstraint
from .float_c import FloatConstraint
from .obj_c import ObjectConstraint
from ..ssa_types import SSA_INT, SSA_LONG, SSA_FLOAT, SSA_DOUBLE, SSA_OBJECT
# joins become more precise (intersection), meets become more general (union)
# Join currently supports joining a max of two constraints
# Meet assumes all inputs are not None
def join(*cons):
if None in cons:
return None
return cons[0].join(*cons[1:])
def meet(*cons):
if not cons:
return None
return cons[0].meet(*cons[1:])
def fromConstant(env, var):
ssa_type = var.type
cval = var.const
if ssa_type[0] == SSA_INT[0]:
return IntConstraint.const(ssa_type[1], cval)
elif ssa_type[0] == SSA_FLOAT[0]:
xt = floatutil.fromRawFloat(ssa_type[1], cval)
return FloatConstraint.const(ssa_type[1], xt)
elif ssa_type[0] == SSA_OBJECT[0]:
if var.decltype == objtypes.NullTT:
return ObjectConstraint.constNull(env)
return ObjectConstraint.fromTops(env, *objtypes.declTypeToActual(env, var.decltype))
_bots = {
SSA_INT: IntConstraint.bot(SSA_INT[1]),
SSA_LONG: IntConstraint.bot(SSA_LONG[1]),
SSA_FLOAT: FloatConstraint.bot(SSA_FLOAT[1]),
SSA_DOUBLE: FloatConstraint.bot(SSA_DOUBLE[1]),
}
def fromVariable(env, var):
if var.const is not None:
return fromConstant(env, var)
ssa_type = var.type
try:
return _bots[ssa_type]
except KeyError:
assert ssa_type == SSA_OBJECT
isnew = var.uninit_orig_num is not None
if var.decltype is not None:
if var.decltype == objtypes.NullTT:
return ObjectConstraint.constNull(env)
return ObjectConstraint.fromTops(env, *objtypes.declTypeToActual(env, var.decltype), nonnull=isnew)
else:
return ObjectConstraint.fromTops(env, [objtypes.ObjectTT], [], nonnull=isnew)
OpReturnInfo = collections.namedtuple('OpReturnInfo', ['rval', 'eval', 'must_throw'])
def returnOrThrow(rval, eval): return OpReturnInfo(rval, eval, False)
def maybeThrow(eval): return OpReturnInfo(None, eval, False)
def throw(eval): return OpReturnInfo(None, eval, True)
def return_(rval): return OpReturnInfo(rval, None, False)

View File

@ -0,0 +1,59 @@
from ... import floatutil as fu
from ..mixin import ValueType
SPECIALS = frozenset((fu.NAN, fu.INF, fu.NINF, fu.ZERO, fu.NZERO))
def botRange(size):
mbits, emin, emax = size
mag = (1<<(mbits+1))-1, emax-mbits
return (-1,mag), (1,mag)
class FloatConstraint(ValueType):
def __init__(self, size, finite, special):
self.size = size
self.finite = finite
self.spec = special
self.isBot = (special == SPECIALS) and (finite == botRange(size))
@staticmethod
def const(size, val):
if val in SPECIALS:
return FloatConstraint(size, (None, None), frozenset([val]))
return FloatConstraint(size, (val, val), frozenset())
@staticmethod
def bot(size):
finite = botRange(size)
return FloatConstraint(size, finite, SPECIALS)
def _key(self): return self.finite, self.spec
def join(*cons): # more precise (intersection)
spec = frozenset.intersection(*[c.spec for c in cons])
ranges = [c.finite for c in cons]
if (None, None) in ranges:
xmin = xmax = None
else:
mins, maxs = zip(*ranges)
xmin = max(mins, key=fu.sortkey)
xmax = min(maxs, key=fu.sortkey)
if fu.sortkey(xmax) < fu.sortkey(xmin):
xmin = xmax = None
if not xmin and not spec:
return None
return FloatConstraint(cons[0].size, (xmin, xmax), spec)
def meet(*cons):
spec = frozenset.union(*[c.spec for c in cons])
ranges = [c.finite for c in cons if c.finite != (None,None)]
if ranges:
mins, maxs = zip(*ranges)
xmin = min(mins, key=fu.sortkey)
xmax = max(maxs, key=fu.sortkey)
else:
xmin = xmax = None
return FloatConstraint(cons[0].size, (xmin, xmax), spec)

View File

@ -0,0 +1,48 @@
from ..mixin import ValueType
class IntConstraint(ValueType):
__slots__ = "width min max".split()
def __init__(self, width, min_, max_):
self.width = width
self.min = min_
self.max = max_
# self.isBot = (-min_ == max_+1 == (1<<width)//2)
@staticmethod
def range(width, min_, max_):
if min_ > max_:
return None
return IntConstraint(width, min_, max_)
@staticmethod
def const(width, val):
return IntConstraint(width, val, val)
@staticmethod
def bot(width):
return IntConstraint(width, -1<<(width-1), (1<<(width-1))-1)
def _key(self): return self.min, self.max
def join(*cons):
xmin = max(c.min for c in cons)
xmax = min(c.max for c in cons)
if xmin > xmax:
return None
res = IntConstraint(cons[0].width, xmin, xmax)
return cons[0] if cons[0] == res else res
def meet(*cons):
xmin = min(c.min for c in cons)
xmax = max(c.max for c in cons)
return IntConstraint(cons[0].width, xmin, xmax)
def __str__(self): # pragma: no cover
t = 'Int' if self.width == 32 else 'Long'
if self.min == self.max:
return '{}({})'.format(t, self.min)
elif self == self.bot(self.width):
return t
return '{}({}, {})'.format(t, self.min, self.max)
__repr__ = __str__

View File

@ -0,0 +1,6 @@
class ValueType(object):
'''Define _key() and inherit from this class to implement comparison and hashing'''
# def __init__(self, *args, **kwargs): super(ValueType, self).__init__(*args, **kwargs)
def __eq__(self, other): return type(self) == type(other) and self._key() == other._key()
def __ne__(self, other): return type(self) != type(other) or self._key() != other._key()
def __hash__(self): return hash(self._key())

View File

@ -0,0 +1,123 @@
import itertools
from .. import objtypes
from ..mixin import ValueType
array_supers = 'java/lang/Object','java/lang/Cloneable','java/io/Serializable'
obj_fset = frozenset([objtypes.ObjectTT])
def isAnySubtype(env, x, seq):
return any(objtypes.isSubtype(env,x,y) for y in seq)
class TypeConstraint(ValueType):
__slots__ = "env supers exact isBot".split()
def __init__(self, env, supers, exact):
self.env, self.supers, self.exact = env, frozenset(supers), frozenset(exact)
self.isBot = objtypes.ObjectTT in supers
temp = self.supers | self.exact
assert objtypes.NullTT not in temp
assert all(objtypes.isBaseTClass(tt) for tt in supers)
assert all(objtypes.dim(tt) < 999 for tt in exact)
def _key(self): return self.supers, self.exact
def __nonzero__(self): return bool(self.supers or self.exact)
def getSingleTType(self):
# comSuper doesn't care about order so we can freely pass in nondeterministic order
return objtypes.commonSupertype(self.env, list(self.supers) + list(self.exact))
def isBoolOrByteArray(self):
if self.supers or len(self.exact) != 2:
return False
tt1, tt2 = self.exact
bases = objtypes.baset(tt1), objtypes.baset(tt2)
return objtypes.dim(tt1) == objtypes.dim(tt2) and sorted(bases) == [objtypes.baset(objtypes.BoolTT), objtypes.baset(objtypes.ByteTT)]
@staticmethod
def reduce(env, supers, exact):
newsupers = []
for x in supers:
if not isAnySubtype(env, x, newsupers):
newsupers = [y for y in newsupers if not objtypes.isSubtype(env, y, x)]
newsupers.append(x)
newexact = [x for x in exact if not isAnySubtype(env, x, newsupers)]
return TypeConstraint(env, newsupers, newexact)
def join(*cons):
assert len(set(map(type, cons))) == 1
env = cons[0].env
# optimize for the common case of joining with itself or with bot
cons = set(c for c in cons if not c.isBot)
if not cons:
return TypeConstraint(env, obj_fset, [])
elif len(cons) == 1:
return cons.pop()
assert(len(cons) == 2) # joining more than 2 not currently supported
supers_l, exact_l = zip(*(c._key() for c in cons))
newsupers = set()
for t1,t2 in itertools.product(*supers_l):
if objtypes.isSubtype(env, t1, t2):
newsupers.add(t1)
elif objtypes.isSubtype(env, t2, t1):
newsupers.add(t2)
else: # TODO: need to add special handling for interfaces here
pass
newexact = frozenset.union(*exact_l)
for c in cons:
newexact = [x for x in newexact if x in c.exact or isAnySubtype(env, x, c.supers)]
return TypeConstraint.reduce(env, newsupers, newexact)
def meet(*cons):
supers = frozenset.union(*(c.supers for c in cons))
exact = frozenset.union(*(c.exact for c in cons))
return TypeConstraint.reduce(cons[0].env, supers, exact)
class ObjectConstraint(ValueType):
__slots__ = "null types isBot".split()
def __init__(self, null, types):
self.null, self.types = null, types
self.isBot = null and types.isBot
@staticmethod
def constNull(env):
return ObjectConstraint(True, TypeConstraint(env, [], []))
@staticmethod
def fromTops(env, supers, exact, nonnull=False):
types = TypeConstraint(env, supers, exact)
if nonnull and not types:
return None
return ObjectConstraint(not nonnull, types)
def _key(self): return self.null, self.types
def isConstNull(self): return self.null and not self.types
def getSingleTType(self):
return self.types.getSingleTType() if self.types else objtypes.NullTT
def join(*cons):
null = all(c.null for c in cons)
types = TypeConstraint.join(*(c.types for c in cons))
if not null and not types:
return None
res = ObjectConstraint(null, types)
return cons[0] if cons[0] == res else res
def meet(*cons):
null = any(c.null for c in cons)
types = TypeConstraint.meet(*(c.types for c in cons))
return ObjectConstraint(null, types)
def __str__(self): # pragma: no cover
if not self.types:
return 'Obj(null)'
return 'Obj({}, {}, {})'.format(self.null, sorted(self.types.supers), sorted(self.types.exact))
__repr__ = __str__

View File

@ -0,0 +1,207 @@
import collections
import itertools
from . import objtypes
from .mixin import ValueType
class CatchSetManager(object):
def __init__(self, env, sets, mask):
self.env, self.sets, self.mask = env, sets, mask
assert not self._conscheck()
@staticmethod # factory
def new(env, chpairs):
sets = collections.OrderedDict() # make this ordered since OnException relies on it
sofar = empty = ExceptionSet.EMPTY
for catchtype, handler in chpairs:
old = sets.get(handler, empty)
new = ExceptionSet.fromTops(env, catchtype)
sets[handler] = old | (new - sofar)
sofar = sofar | new
return CatchSetManager(env, sets, sofar)
def newMask(self, mask):
for k in self.sets:
self.sets[k] &= mask
self.mask &= mask
assert not self._conscheck()
def pruneKeys(self):
for handler, catchset in list(self.sets.items()):
if not catchset:
del self.sets[handler]
def copy(self):
return CatchSetManager(self.env, self.sets.copy(), self.mask)
def replaceKeys(self, replace):
self.sets = collections.OrderedDict((replace.get(key,key), val) for key, val in self.sets.items())
def _conscheck(self):
temp = ExceptionSet.EMPTY
for v in self.sets.values():
assert not v & temp
temp |= v
assert temp == self.mask
assert isinstance(self.sets, collections.OrderedDict)
class ExceptionSet(ValueType):
__slots__ = "env pairs".split()
def __init__(self, env, pairs): # assumes arguments are in reduced form
self.env = env
self.pairs = frozenset([(x,frozenset(y)) for x,y in pairs])
# We allow env to be None for the empty set so we can construct empty sets easily
# Any operation resulting in a nonempty set will get its env from the nonempty argument
assert self.empty() or self.env is not None
# make sure set is fully reduced
parts = []
for t, holes in pairs:
parts.append(t)
parts.extend(holes)
assert len(set(parts)) == len(parts)
@staticmethod # factory
def fromTops(env, *tops):
return ExceptionSet(env, [(x, frozenset()) for x in tops])
def _key(self): return self.pairs
def empty(self): return not self.pairs
def __nonzero__(self): return bool(self.pairs)
def getTopTTs(self): return sorted([objtypes.TypeTT(top,0) for (top,holes) in self.pairs])
def __sub__(self, other):
assert type(self) == type(other)
if self.empty() or other.empty():
return self
if self == other:
return ExceptionSet.EMPTY
subtest = self.env.isSubclass
pairs = self.pairs
for pair2 in other.pairs:
# Warning, due to a bug in Python, TypeErrors raised inside the gen expr will give an incorect error message
# TypeError: type object argument after * must be a sequence, not generator
# This can be worked around by using a list comprehension instead of a genexpr after the *
pairs = itertools.chain(*[ExceptionSet.diffPair(subtest, pair1, pair2) for pair1 in pairs])
return ExceptionSet.reduce(self.env, pairs)
def __or__(self, other):
assert type(self) == type(other)
if other.empty() or self == other:
return self
if self.empty():
return other
return ExceptionSet.reduce(self.env, self.pairs | other.pairs)
def __and__(self, other):
assert type(self) == type(other)
new = self - (self - other)
return new
def isdisjoint(self, other):
return (self-other) == self
def __str__(self): # pragma: no cover
parts = [('{} - [{}]'.format(top, ', '.join(sorted(holes))) if holes else top) for top, holes in self.pairs]
return 'ES[{}]'.format(', '.join(parts))
__repr__ = __str__
@staticmethod
def diffPair(subtest, pair1, pair2): # subtract pair2 from pair1. Returns a list of new pairs
# todo - find way to make this less ugly
t1, holes1 = pair1
t2, holes2 = pair2
if subtest(t1,t2): # t2 >= t1
if any(subtest(t1, h) for h in holes2):
return pair1,
else:
newpairs = []
holes2 = [h for h in holes2 if subtest(h, t1) and not any(subtest(h,h2) for h2 in holes1)]
for h in holes2:
newholes = [h2 for h2 in holes1 if subtest(h2, h)]
newpairs.append((h, newholes))
return newpairs
elif subtest(t2,t1): # t2 < t1
if any(subtest(t2, h) for h in holes1):
return pair1,
else:
newpairs = [(t1,ExceptionSet.reduceHoles(subtest, list(holes1)+[t2]))]
holes2 = [h for h in holes2 if not any(subtest(h,h2) for h2 in holes1)]
for h in holes2:
newholes = [h2 for h2 in holes1 if subtest(h2, h)]
newpairs.append((h, newholes))
return newpairs
else:
return pair1,
@staticmethod
def mergePair(subtest, pair1, pair2): # merge pair2 into pair1 and return the union
t1, holes1 = pair1
t2, holes2 = pair2
assert subtest(t2,t1)
if t2 in holes1:
holes1 = list(holes1)
holes1.remove(t2)
return t1, holes1 + list(holes2)
# TODO - this can probably be made more efficient
holes1a = set(h for h in holes1 if not subtest(h, t2))
holes1b = [h for h in holes1 if h not in holes1a]
merged_holes = set()
for h1, h2 in itertools.product(holes1b, holes2):
if subtest(h2, h1):
merged_holes.add(h1)
elif subtest(h1, h2):
merged_holes.add(h2)
merged_holes = ExceptionSet.reduceHoles(subtest, merged_holes)
assert len(merged_holes) <= len(holes1b) + len(holes2)
return t1, (list(holes1a) + merged_holes)
@staticmethod
def reduceHoles(subtest, holes):
newholes = []
for hole in holes:
for ehole in newholes:
if subtest(hole, ehole):
break
else:
newholes = [hole] + [h for h in newholes if not subtest(h, hole)]
return newholes
@staticmethod
def reduce(env, pairs):
subtest = env.isSubclass
pairs = [pair for pair in pairs if pair[0] not in pair[1]] # remove all degenerate pairs
newpairs = []
while pairs:
top, holes = pair = pairs.pop()
# look for an existing top to merge into
for epair in newpairs[:]:
etop, eholes = epair
# new pair can be merged into existing pair
if subtest(top, etop) and (top in eholes or not any(subtest(top, ehole) for ehole in eholes)):
new = ExceptionSet.mergePair(subtest, epair, pair)
newpairs, pairs = [new], [p for p in newpairs if p is not epair] + pairs
break
# existing pair can be merged into new pair
elif subtest(etop, top) and (etop in holes or not any(subtest(etop, hole) for hole in holes)):
new = ExceptionSet.mergePair(subtest, pair, epair)
newpairs, pairs = [new], [p for p in newpairs if p is not epair] + pairs
break
# pair is incomparable to all existing pairs
else:
holes = ExceptionSet.reduceHoles(subtest, holes)
newpairs.append((top,holes))
return ExceptionSet(env, newpairs)
ExceptionSet.EMPTY = ExceptionSet(None, [])

View File

@ -0,0 +1,11 @@
from . import objtypes
# common exception types
Arithmetic = objtypes.TypeTT('java/lang/ArithmeticException', 0)
ArrayOOB = objtypes.TypeTT('java/lang/ArrayIndexOutOfBoundsException', 0)
ArrayStore = objtypes.TypeTT('java/lang/ArrayStoreException', 0)
ClassCast = objtypes.TypeTT('java/lang/ClassCastException', 0)
MonState = objtypes.TypeTT('java/lang/IllegalMonitorStateException', 0)
NegArrSize = objtypes.TypeTT('java/lang/NegativeArraySizeException', 0)
NullPtr = objtypes.TypeTT('java/lang/NullPointerException', 0)
OOM = objtypes.TypeTT('java/lang/OutOfMemoryError', 0)

View File

@ -0,0 +1,9 @@
class SSAFunctionBase(object):
def __init__(self, parent, arguments):
self.parent = parent
self.params = list(arguments)
assert None not in self.params
def replaceVars(self, rdict):
self.params = [rdict.get(x,x) for x in self.params]
assert None not in self.params

View File

@ -0,0 +1,764 @@
import collections
import copy
import functools
import itertools
from .. import graph_util
from ..verifier.descriptors import parseUnboundMethodDescriptor
from . import blockmaker, constraints, objtypes, ssa_jumps, ssa_ops, subproc
from .ssa_types import BasicBlock, SSA_OBJECT, verifierToSSAType
class SSA_Variable(object):
__slots__ = 'type','origin','name','const','decltype','uninit_orig_num'
def __init__(self, type_, origin=None, name=""):
self.type = type_ # SSA_INT, SSA_OBJECT, etc.
self.origin = origin
self.name = name
self.const = None
self.decltype = None # for objects, the inferred type from the verifier if any
self.uninit_orig_num = None # if uninitialized, the bytecode offset of the new instr
# for debugging
def __str__(self): # pragma: no cover
return self.name if self.name else super(SSA_Variable, self).__str__()
def __repr__(self): # pragma: no cover
name = self.name if self.name else "@" + hex(id(self))
return "Var {}".format(name)
# This class is the main IR for bytecode level methods. It consists of a control
# flow graph (CFG) in static single assignment form (SSA). Each node in the
# graph is a BasicBlock. This consists of a list of phi statements representing
# inputs, a list of operations, and a jump statement. Exceptions are represented
# explicitly in the graph with the OnException jump. Each block also keeps track
# of the unary constraints on the variables in that block.
# Handling of subprocedures is rather annoying. Each complete subproc has an associated
# ProcInfo while jsrs and rets are represented by ProcCallOp and DummyRet respectively.
# The jsrblock has the target and fallthrough as successors, while the fallthrough has
# the jsrblock as predecessor, but not the retblock. Control flow paths where the proc
# never returns are represented by ordinary jumps from blocks in the procedure to outside
# Successful completion of the proc is represented by the fallthrough edge. The fallthrough
# block gets its variables from the jsrblock, including skip vars which don't depend on the
# proc, and variables from jsr.output which represent what would have been returned from ret
# Every proc has a reachable retblock. Jsrs with no associated ret are simply turned
# into gotos during the initial basic block creation.
class SSA_Graph(object):
entryKey = blockmaker.ENTRY_KEY
def __init__(self, code):
self.code = code
self.class_ = code.class_
self.env = self.class_.env
self.inputArgs = None
self.entryBlock = None
self.blocks = None
self.procs = None # used to store information on subprocedues (from the JSR instructions)
self.block_numberer = itertools.count(-4,-1)
def condenseBlocks(self):
assert not self.procs
old = self.blocks
# Can't do a consistency check on entry as the graph may be in an inconsistent state at this point
# Since the purpose of this function is to prune unreachable blocks from self.blocks
sccs = graph_util.tarjanSCC([self.entryBlock], lambda block:block.jump.getSuccessors())
self.blocks = list(itertools.chain.from_iterable(map(reversed, sccs[::-1])))
assert set(self.blocks) <= set(old)
if len(self.blocks) < len(old):
kept = set(self.blocks)
for block in self.blocks:
for pair in block.predecessors[:]:
if pair[0] not in kept:
block.removePredPair(pair)
return [b for b in old if b not in kept]
return []
def removeUnusedVariables(self):
assert not self.procs
roots = [x for x in self.inputArgs if x is not None]
for block in self.blocks:
roots += block.jump.params
for op in block.lines:
if op.has_side_effects:
roots += op.params
reachable = graph_util.topologicalSort(roots, lambda var:(var.origin.params if var.origin else []))
keepset = set(reachable)
assert None not in keepset
def filterOps(oldops):
newops = []
for op in oldops:
# if any of the params is being removed due to being unreachable, we can assume the whole function can be removed
keep = keepset.issuperset(op.params) and (op.has_side_effects or not keepset.isdisjoint(op.getOutputs()))
if keep:
newops.append(op)
for v in op.getOutputs():
if v and v not in keepset:
op.removeOutput(v)
else:
assert keepset.isdisjoint(op.getOutputs())
assert not op.has_side_effects
return newops
for block in self.blocks:
block.phis = filterOps(block.phis)
block.lines = filterOps(block.lines)
block.filterVarConstraints(keepset)
assert self._conscheck() is None
def mergeSingleSuccessorBlocks(self):
assert(not self.procs) # Make sure that all single jsr procs are inlined first
assert self._conscheck() is None
removed = set()
for block in self.blocks:
if block in removed:
continue
while isinstance(block.jump, ssa_jumps.Goto):
jump = block.jump
block2 = jump.getNormalSuccessors()[0]
fromkey = block, False
if block2.predecessors != [fromkey]:
break
jump2 = block2.jump
ucs = block.unaryConstraints
ucs2 = block2.unaryConstraints
replace = {phi.rval: phi.get(fromkey) for phi in block2.phis}
for var2, var in replace.items():
ucs[var] = constraints.join(ucs[var], ucs2.pop(var2))
ucs.update(ucs2)
for op in block2.lines:
op.replaceVars(replace)
block.lines += block2.lines
jump2.replaceVars(replace)
block.jump = jump2
# remember to update phis of blocks referring to old child!
for successor, t in block.jump.getSuccessorPairs():
successor.replacePredPair((block2, t), (block, t))
for phi in successor.phis:
phi.replaceVars(replace)
removed.add(block2)
self.blocks = [b for b in self.blocks if b not in removed]
assert self._conscheck() is None
def disconnectConstantVariables(self):
for block in self.blocks:
for var, uc in block.unaryConstraints.items():
if var.origin is not None:
newval = None
if var.type[0] == 'int':
if uc.min == uc.max:
newval = uc.min
elif var.type[0] == 'obj':
if uc.isConstNull():
newval = 'null'
if newval is not None:
var.origin.removeOutput(var)
var.origin = None
var.const = newval
block.phis = [phi for phi in block.phis if phi.rval is not None]
assert self._conscheck() is None
def _conscheck(self):
'''Sanity check'''
for block in self.blocks:
assert block.jump is not None
for phi in block.phis:
assert phi.rval is None or phi.rval in block.unaryConstraints
for k,v in phi.dict.items():
assert v in k[0].unaryConstraints
keys = [block.key for block in self.blocks]
assert len(set(keys)) == len(keys)
temp = [self.entryBlock]
for proc in self.procs:
temp += [proc.retblock]
temp += proc.jsrblocks
assert len(set(temp)) == len(temp)
def copyPropagation(self):
# Loop aware copy propagation
assert not self.procs
assert self._conscheck() is None
# The goal is to propagate constants that would never be inferred pessimistically
# due to the prescence of loops. Variables that aren't derived from a constant or phi
# are treated as opaque and variables are processed by SCC in topological order.
# For each scc, we can infer that it is the meet of all inputs that come from variables
# in different sccs that come before it in topological order, thus ignoring variables
# in the current scc (the loop problem).
v2b = {}
assigns = collections.OrderedDict()
for block in self.blocks:
for var in block.unaryConstraints:
v2b[var] = block
for phi in block.phis:
assigns[phi.rval] = map(phi.get, block.predecessors)
UCs = {}
sccs = graph_util.tarjanSCC(assigns, lambda v:assigns.get(v, []))
for scc in sccs:
if all(var in assigns for var in scc):
invars = sum(map(assigns.get, scc), [])
inputs = [UCs[invar] for invar in invars if invar in UCs]
assert inputs
uc = constraints.meet(*inputs)
for var in scc:
old = v2b[var].unaryConstraints[var]
new = constraints.join(uc, old) or old # temporary hack
v2b[var].unaryConstraints[var] = UCs[var] = new
else:
# There is a root in this scc, so we can't do anything
for var in scc:
UCs[var] = v2b[var].unaryConstraints[var]
assert self._conscheck() is None
def abstractInterpert(self):
# Sparse conditional constant propagation and type inference
assert not self.procs
assert self._conscheck() is None
visit_counts = collections.defaultdict(int)
dirty_phis = set(itertools.chain.from_iterable(block.phis for block in self.blocks))
while dirty_phis:
for block in self.blocks:
assert block in self.blocks
UCs = block.unaryConstraints
assert None not in UCs.values()
dirty = visit_counts[block] == 0
for phi in block.phis:
if phi in dirty_phis:
dirty_phis.remove(phi)
inputs = [key[0].unaryConstraints[phi.get(key)] for key in block.predecessors]
out = constraints.meet(*inputs)
old = UCs[phi.rval]
UCs[phi.rval] = out = constraints.join(old, out)
dirty = dirty or out != old
assert out
if not dirty or visit_counts[block] >= 5:
continue
visit_counts[block] += 1
must_throw = False
dirty_vars = set()
last_line = block.lines[-1] if block.lines else None # Keep reference handy to exception phi, if any
for i, op in enumerate(block.lines):
if hasattr(op, 'propagateConstraints'):
output_vars = op.getOutputs()
inputs = [UCs[var] for var in op.params]
assert None not in inputs
output_info = op.propagateConstraints(*inputs)
for var, out in zip(output_vars, [output_info.rval, output_info.eval]):
if var is None:
continue
old = UCs[var]
UCs[var] = out = constraints.join(old, out)
if out is None:
if var is op.outException:
assert isinstance(last_line, ssa_ops.ExceptionPhi)
last_line.params.remove(var)
op.removeOutput(var) # Note, this must be done after the op.outException check!
del UCs[var]
elif out != old:
dirty_vars.add(var)
if output_info.must_throw:
must_throw = True
# Remove all code after this in the basic block and adjust exception code
# at end as appropriate
assert isinstance(last_line, ssa_ops.ExceptionPhi)
assert i < len(block.lines) and op.outException
removed = block.lines[i+1:-1]
block.lines = block.lines[:i+1] + [last_line]
for op2 in removed:
if op2.outException:
last_line.params.remove(op2.outException)
for var in op2.getOutputs():
if var is not None:
del UCs[var]
break
# now handle end of block
if isinstance(last_line, ssa_ops.ExceptionPhi):
inputs = map(UCs.get, last_line.params)
out = constraints.meet(*inputs)
old = UCs[last_line.outException]
assert out is None or not out.null
UCs[last_line.outException] = out = constraints.join(old, out)
if out is None:
del UCs[last_line.outException]
block.lines.pop()
elif out != old:
dirty_vars.add(last_line.outException)
# prune jumps
dobreak = False
if hasattr(block.jump, 'constrainJumps'):
assert block.jump.params
oldEdges = block.jump.getSuccessorPairs()
inputs = map(UCs.get, block.jump.params)
block.jump = block.jump.constrainJumps(*inputs)
# No exception case ordinarily won't be pruned, so we have to handle it explicitly
if must_throw and isinstance(block.jump, ssa_jumps.OnException):
if block.jump.getNormalSuccessors(): # make sure it wasn't already pruned
fallthrough = block.jump.getNormalSuccessors()[0]
block.jump = block.jump.reduceSuccessors([(fallthrough, False)])
newEdges = block.jump.getSuccessorPairs()
if newEdges != oldEdges:
pruned = [x for x in oldEdges if x not in newEdges]
for (child,t) in pruned:
child.removePredPair((block,t))
removed_blocks = self.condenseBlocks()
# In case where no blocks were removed, self.blocks will possibly be in a different
# order than the version of self.blocks we are iterating over, but it still has the
# same contents, so this should be safe. If blocks were removed, we break out of the
# list and restart to avoid the possibility of processing an unreachable block.
dobreak = len(removed_blocks) > 0
for removed in removed_blocks:
for phi in removed.phis:
dirty_phis.discard(phi)
# update dirty set
for child, t in block.jump.getSuccessorPairs():
assert child in self.blocks
for phi in child.phis:
if phi.get((block, t)) in dirty_vars:
dirty_phis.add(phi)
if dobreak:
break
# Try to turn switches into if statements - note that this may
# introduce a new variable and this modify block.unaryConstraints
# However, it won't change the control flow graph structure
for block in self.blocks:
if isinstance(block.jump, ssa_jumps.Switch):
block.jump = block.jump.simplifyToIf(block)
def simplifyThrows(self):
# Try to turn throws into gotos where possible. This primarily helps with certain patterns of try-with-resources
# To do this, the exception must be known to be non null and there must be only one target that can catch it
# As a heuristic, we also restrict it to cases where every predecessor of the target can be converted
candidates = collections.defaultdict(list)
for block in self.blocks:
if not isinstance(block.jump, ssa_jumps.OnException) or len(block.jump.getSuccessorPairs()) != 1:
continue
if len(block.lines[-1].params) != 1 or not isinstance(block.lines[-2], ssa_ops.Throw):
continue
if block.unaryConstraints[block.lines[-2].params[0]].null:
continue
candidates[block.jump.getExceptSuccessors()[0]].append(block)
for child in self.blocks:
if not candidates[child] or len(candidates[child]) < len(child.predecessors):
continue
for parent in candidates[child]:
ephi = parent.lines.pop()
throw_op = parent.lines.pop()
var1 = throw_op.params[0]
var2 = throw_op.outException
assert ephi.params == [var2]
var3 = ephi.outException
assert parent.jump.params[0] == var3
for phi in child.phis:
phi.replaceVars({var3: var1})
child.replacePredPair((parent, True), (parent, False))
del parent.unaryConstraints[var2]
del parent.unaryConstraints[var3]
parent.jump = ssa_jumps.Goto(self, child)
def simplifyCatchIgnored(self):
# When there is a single throwing instruction, which is garuenteed to throw, has a single handler, and
# the caught exception is unused, turn it into a goto. This simplifies a pattern used by some obfuscators
# that do stuff like try{new int[-1];} catch(Exception e) {...}
candidates = collections.defaultdict(list)
for block in self.blocks:
if not isinstance(block.jump, ssa_jumps.OnException) or len(block.jump.getSuccessorPairs()) != 1:
continue
if len(block.lines[-1].params) != 1:
continue
candidates[block.jump.getExceptSuccessors()[0]].append(block)
for child in self.blocks:
if not candidates[child] or len(candidates[child]) < len(child.predecessors):
continue
# Make sure caught exception is unused
temp = candidates[child][0].lines[-1].outException
if any(temp in phi.params for phi in child.phis):
continue
for parent in candidates[child]:
ephi = parent.lines.pop()
throw_op = parent.lines.pop()
del parent.unaryConstraints[throw_op.outException]
del parent.unaryConstraints[ephi.outException]
child.replacePredPair((parent, True), (parent, False))
parent.jump = ssa_jumps.Goto(self, child)
# Subprocedure stuff #####################################################
def _newBlockFrom(self, block):
b = BasicBlock(next(self.block_numberer))
self.blocks.append(b)
return b
def _copyVar(self, var, vard=None):
v = copy.copy(var)
v.name = v.origin = None # TODO - generate new names?
if vard is not None:
vard[var] = v
return v
def _region(self, proc):
# Find the set of blocks 'in' a subprocedure, i.e. those reachable from the target that can reach the ret block
region = graph_util.topologicalSort([proc.retblock], lambda block:[] if block == proc.target else [b for b,t in block.predecessors])
temp = set(region)
assert self.entryBlock not in temp and proc.target in temp and temp.isdisjoint(proc.jsrblocks)
return region
def _duplicateBlocks(self, region, excludedPreds):
# Duplicate a region of blocks. All inedges will be redirected to the new blocks
# except for those from excludedPreds
excludedPreds = excludedPreds | set(region)
outsideBlocks = [b for b in self.blocks if b not in excludedPreds]
blockd, vard = {}, {}
for oldb in region:
block = blockd[oldb] = self._newBlockFrom(oldb)
block.unaryConstraints = {self._copyVar(k, vard):v for k, v in oldb.unaryConstraints.items()}
block.phis = [ssa_ops.Phi(block, vard[oldphi.rval]) for oldphi in oldb.phis]
for op in oldb.lines:
new = copy.copy(op)
new.replaceVars(vard)
new.replaceOutVars(vard)
assert new.getOutputs().count(None) == op.getOutputs().count(None)
for outv in new.getOutputs():
if outv is not None:
assert outv.origin is None
outv.origin = new
block.lines.append(new)
assert set(vard).issuperset(oldb.jump.params)
block.jump = oldb.jump.clone()
block.jump.replaceVars(vard)
# Fix up blocks outside the region that jump into the region.
for key in oldb.predecessors[:]:
pred = key[0]
if pred not in excludedPreds:
for phi1, phi2 in zip(oldb.phis, block.phis):
phi2.add(key, phi1.get(key))
del phi1.dict[key]
oldb.predecessors.remove(key)
block.predecessors.append(key)
# fix up jump targets of newly created blocks
for oldb, block in blockd.items():
block.jump.replaceBlocks(blockd)
for suc, t in block.jump.getSuccessorPairs():
suc.predecessors.append((block, t))
# update the jump targets of predecessor blocks
for block in outsideBlocks:
block.jump.replaceBlocks(blockd)
for old, new in vard.items():
assert type(old.origin) == type(new.origin)
# Fill in phi args in successors of new blocks
for oldb, block in blockd.items():
for oldc, t in oldb.jump.getSuccessorPairs():
child = blockd.get(oldc, oldc)
assert len(child.phis) == len(oldc.phis)
for phi1, phi2 in zip(oldc.phis, child.phis):
phi2.add((block, t), vard[phi1.get((oldb, t))])
assert self._conscheck() is None
return blockd
def _splitSubProc(self, proc):
# Splits a proc into two, with one callsite using the new proc instead
# this involves duplicating the body of the procedure
# the new proc is appended to the list of procs so it can work properly
# with the stack processing in inlineSubprocs
assert len(proc.jsrblocks) > 1
target, retblock = proc.target, proc.retblock
region = self._region(proc)
split_jsrs = [proc.jsrblocks.pop()]
blockd = self._duplicateBlocks(region, set(proc.jsrblocks))
newproc = subproc.ProcInfo(blockd[proc.retblock], blockd[proc.target])
newproc.jsrblocks = split_jsrs
# Sanity check
for temp in self.procs + [newproc]:
for jsr in temp.jsrblocks:
assert jsr.jump.target == temp.target
return newproc
def _inlineSubProc(self, proc):
# Inline a proc with single callsite inplace
assert len(proc.jsrblocks) == 1
target, retblock = proc.target, proc.retblock
region = self._region(proc)
jsrblock = proc.jsrblocks[0]
jsrop = jsrblock.jump
ftblock = jsrop.fallthrough
# first we find any vars that bypass the proc since we have to pass them through the new blocks
skipvars = [phi.get((jsrblock, False)) for phi in ftblock.phis]
skipvars = [var for var in skipvars if var.origin is not jsrop]
svarcopy = {(var, block):self._copyVar(var) for var, block in itertools.product(skipvars, region)}
for var, block in itertools.product(skipvars, region):
# Create a new phi for the passed through var for this block
rval = svarcopy[var, block]
phi = ssa_ops.Phi(block, rval)
block.phis.append(phi)
block.unaryConstraints[rval] = jsrblock.unaryConstraints[var]
for key in block.predecessors:
if key == (jsrblock, False):
phi.add(key, var)
else:
phi.add(key, svarcopy[var, key[0]])
outreplace = {jv:rv for jv, rv in zip(jsrblock.jump.output.stack, retblock.jump.input.stack) if jv is not None}
outreplace.update({jv:retblock.jump.input.locals[i] for i, jv in jsrblock.jump.output.locals.items() if jv is not None})
for var in outreplace: # don't need jsrop's out vars anymore
del jsrblock.unaryConstraints[var]
for var in skipvars:
outreplace[var] = svarcopy[var, retblock]
jsrblock.jump = ssa_jumps.Goto(self, target)
retblock.jump = ssa_jumps.Goto(self, ftblock)
ftblock.replacePredPair((jsrblock, False), (retblock, False))
for phi in ftblock.phis:
phi.replaceVars(outreplace)
def inlineSubprocs(self):
assert self._conscheck() is None
assert self.procs
# establish DAG of subproc callstacks if we're doing nontrivial inlining, since we can only inline leaf procs
regions = {proc:frozenset(self._region(proc)) for proc in self.procs}
parents = {proc:[] for proc in self.procs}
for x,y in itertools.product(self.procs, repeat=2):
if not regions[y].isdisjoint(x.jsrblocks):
parents[x].append(y)
self.procs = graph_util.topologicalSort(self.procs, parents.get)
if any(parents.values()):
print 'Warning, nesting subprocedures detected! This method may take a long time to decompile.'
print 'Subprocedures for', self.code.method.name + ':', self.procs
# now inline the procs
while self.procs:
proc = self.procs.pop()
while len(proc.jsrblocks) > 1:
print 'splitting', proc
# push new subproc onto stack
self.procs.append(self._splitSubProc(proc))
assert self._conscheck() is None
# When a subprocedure has only one call point, it can just be inlined instead of splitted
print 'inlining', proc
self._inlineSubProc(proc)
assert self._conscheck() is None
##########################################################################
def splitDualInedges(self):
# Split any blocks that have both normal and exceptional in edges
assert not self.procs
for block in self.blocks[:]:
if block is self.entryBlock:
continue
types = set(zip(*block.predecessors)[1])
if len(types) <= 1:
continue
assert not isinstance(block.jump, (ssa_jumps.Return, ssa_jumps.Rethrow))
new = self._newBlockFrom(block)
print 'Splitting', block, '->', new
# first fix up CFG edges
badpreds = [t for t in block.predecessors if t[1]]
new.predecessors = badpreds
for t in badpreds:
block.predecessors.remove(t)
for pred, _ in badpreds:
assert isinstance(pred.jump, ssa_jumps.OnException)
pred.jump.replaceExceptTarget(block, new)
new.jump = ssa_jumps.Goto(self, block)
block.predecessors.append((new, False))
# fix up variables
new.phis = []
new.unaryConstraints = {}
for phi in block.phis:
newrval = self._copyVar(phi.rval)
new.unaryConstraints[newrval] = block.unaryConstraints[phi.rval]
newphi = ssa_ops.Phi(new, newrval)
new.phis.append(newphi)
for t in badpreds:
arg = phi.get(t)
phi.delete(t)
newphi.add(t, arg)
phi.add((new, False), newrval)
assert self._conscheck() is None
def fixLoops(self):
assert not self.procs
todo = self.blocks[:]
while todo:
newtodo = []
temp = set(todo)
sccs = graph_util.tarjanSCC(todo, lambda block:[x for x,t in block.predecessors if x in temp])
for scc in sccs:
if len(scc) <= 1:
continue
scc_pair_set = {(x, False) for x in scc} | {(x, True) for x in scc}
entries = [n for n in scc if not scc_pair_set.issuperset(n.predecessors)]
if len(entries) <= 1:
head = entries[0]
else:
# if more than one entry point into the loop, we have to choose one as the head and duplicate the rest
print 'Warning, multiple entry point loop detected. Generated code may be extremely large',
print '({} entry points, {} blocks)'.format(len(entries), len(scc))
def loopSuccessors(head, block):
if block == head:
return []
return [x for x in block.jump.getSuccessors() if (x, False) in scc_pair_set]
reaches = [(n, graph_util.topologicalSort(entries, functools.partial(loopSuccessors, n))) for n in scc]
for head, reachable in reaches:
reachable.remove(head)
head, reachable = min(reaches, key=lambda t:(len(t[1]), -len(t[0].predecessors)))
assert head not in reachable
print 'Duplicating {} nodes'.format(len(reachable))
blockd = self._duplicateBlocks(reachable, set(scc) - set(reachable))
newtodo += map(blockd.get, reachable)
newtodo.extend(scc)
newtodo.remove(head)
todo = newtodo
assert self._conscheck() is None
# Functions called by children ###########################################
# assign variable names for debugging
varnum = collections.defaultdict(itertools.count)
def makeVariable(self, *args, **kwargs):
# Note: Make sure this doesn't hold on to created variables in any way,
# since this func may be called for temporary results that are discarded
var = SSA_Variable(*args, **kwargs)
# pref = args[0][0][0].replace('o','a')
# var.name = pref + str(next(self.varnum[pref]))
return var
def setObjVarData(self, var, vtype, initMap):
vtype2 = initMap.get(vtype, vtype)
tt = objtypes.verifierToSynthetic(vtype2)
assert var.decltype is None or var.decltype == tt
var.decltype = tt
# if uninitialized, record the offset of originating new instruction for later
if vtype.tag == '.new':
assert var.uninit_orig_num is None or var.uninit_orig_num == vtype.extra
var.uninit_orig_num = vtype.extra
def makeVarFromVtype(self, vtype, initMap):
vtype2 = initMap.get(vtype, vtype)
type_ = verifierToSSAType(vtype2)
if type_ is not None:
var = self.makeVariable(type_)
if type_ == SSA_OBJECT:
self.setObjVarData(var, vtype, initMap)
return var
return None
def getConstPoolArgs(self, index):
return self.class_.cpool.getArgs(index)
def getConstPoolType(self, index):
return self.class_.cpool.getType(index)
def ssaFromVerified(code, iNodes, opts):
method = code.method
inputTypes, returnTypes = parseUnboundMethodDescriptor(method.descriptor, method.class_.name, method.static)
parent = SSA_Graph(code)
data = blockmaker.BlockMaker(parent, iNodes, inputTypes, returnTypes, code.except_raw, opts=opts)
parent.blocks = blocks = data.blocks
parent.entryBlock = data.entryBlock
parent.inputArgs = data.inputArgs
assert parent.entryBlock in blocks
# create subproc info
procd = {block.jump.target: subproc.ProcInfo(block, block.jump.target) for block in blocks if isinstance(block.jump, subproc.DummyRet)}
for block in blocks:
if isinstance(block.jump, subproc.ProcCallOp):
procd[block.jump.target].jsrblocks.append(block)
parent.procs = sorted(procd.values(), key=lambda p:p.target.key)
# Intern constraints to save a bit of memory for long methods
def makeConstraint(var, _cache={}):
key = var.type, var.const, var.decltype, var.uninit_orig_num is None
try:
return _cache[key]
except KeyError:
_cache[key] = temp = constraints.fromVariable(parent.env, var)
return temp
# create unary constraints for each variable
for block in blocks:
bvars = []
if isinstance(block.jump, subproc.ProcCallOp):
bvars += block.jump.flatOutput()
# entry block has no phis
if block is parent.entryBlock:
bvars += parent.inputArgs
bvars = [v for v in bvars if v is not None]
bvars += [phi.rval for phi in block.phis]
for op in block.lines:
bvars += op.params
bvars += [x for x in op.getOutputs() if x is not None]
bvars += block.jump.params
for suc, t in block.jump.getSuccessorPairs():
for phi in suc.phis:
bvars.append(phi.get((block, t)))
assert None not in bvars
# Note that makeConstraint can indirectly cause class loading
block.unaryConstraints = {var:makeConstraint(var) for var in bvars}
parent._conscheck()
return parent

View File

@ -0,0 +1,6 @@
class ValueType(object):
'''Define _key() and inherit from this class to implement comparison and hashing'''
# def __init__(self, *args, **kwargs): super(ValueType, self).__init__(*args, **kwargs)
def __eq__(self, other): return type(self) == type(other) and self._key() == other._key()
def __ne__(self, other): return type(self) != type(other) or self._key() != other._key()
def __hash__(self): return hash(self._key())

View File

@ -0,0 +1,126 @@
from ..verifier import verifier_types as vtypes
# types are represented by classname, dimension
# primitive types are .int, etc since these cannot be valid classnames since periods are forbidden
def TypeTT(baset, dim):
assert dim >= 0
return baset, dim
# Not real types
VoidTT = TypeTT('.void', 0)
NullTT = TypeTT('.null', 0)
ObjectTT = TypeTT('java/lang/Object', 0)
StringTT = TypeTT('java/lang/String', 0)
ThrowableTT = TypeTT('java/lang/Throwable', 0)
ClassTT = TypeTT('java/lang/Class', 0)
BoolTT = TypeTT('.boolean', 0)
IntTT = TypeTT('.int', 0)
LongTT = TypeTT('.long', 0)
FloatTT = TypeTT('.float', 0)
DoubleTT = TypeTT('.double', 0)
ByteTT = TypeTT('.byte', 0)
CharTT = TypeTT('.char', 0)
ShortTT = TypeTT('.short', 0)
BExpr = '.bexpr' # bool or byte
def baset(tt): return tt[0]
def dim(tt): return tt[1]
def withDimInc(tt, inc): return TypeTT(baset(tt), dim(tt)+inc)
def withNoDim(tt): return TypeTT(baset(tt), 0)
def isBaseTClass(tt): return not baset(tt).startswith('.')
def className(tt): return baset(tt) if not baset(tt).startswith('.') else None
def primName(tt): return baset(tt)[1:] if baset(tt).startswith('.') else None
###############################################################################
def isSubtype(env, x, y):
if x == y or y == ObjectTT or x == NullTT:
return True
elif y == NullTT:
return False
xname, xdim = baset(x), dim(x)
yname, ydim = baset(y), dim(y)
if ydim > xdim:
return False
elif xdim > ydim: # TODO - these constants should be defined in one place to reduce risk of typos
return yname in ('java/lang/Object','java/lang/Cloneable','java/io/Serializable')
else:
return isBaseTClass(x) and isBaseTClass(y) and env.isSubclass(xname, yname)
# Will not return interface unless all inputs are same interface or null
def commonSupertype(env, tts):
assert(hasattr(env, 'getClass')) # catch common errors where we forget the env argument
tts = set(tts)
tts.discard(NullTT)
if len(tts) == 1:
return tts.pop()
elif not tts:
return NullTT
dims = map(dim, tts)
newdim = min(dims)
if max(dims) > newdim or any(baset(tt) == 'java/lang/Object' for tt in tts):
return TypeTT('java/lang/Object', newdim)
# if any are primitive arrays, result is object array of dim-1
if not all(isBaseTClass(tt) for tt in tts):
return TypeTT('java/lang/Object', newdim-1)
# find common superclass of base types
bases = sorted(map(baset, tts))
superclass = reduce(env.commonSuperclass, bases)
return TypeTT(superclass, newdim)
######################################################################################################
_verifierConvert = {vtypes.T_INT:IntTT, vtypes.T_FLOAT:FloatTT, vtypes.T_LONG:LongTT,
vtypes.T_DOUBLE:DoubleTT, vtypes.T_SHORT:ShortTT, vtypes.T_CHAR:CharTT,
vtypes.T_BYTE:ByteTT, vtypes.T_BOOL:BoolTT, vtypes.T_NULL:NullTT,
vtypes.OBJECT_INFO:ObjectTT}
def verifierToSynthetic_seq(vts):
return [verifierToSynthetic(vt) for vt in vts if vt != vtypes.T_INVALID]
def verifierToSynthetic(vtype):
assert vtype.tag not in (None, '.address', '.new', '.init')
vtype = vtypes.withNoConst(vtype)
if vtype in _verifierConvert:
return _verifierConvert[vtype]
base = vtypes.withNoDimension(vtype)
if base in _verifierConvert:
return withDimInc(_verifierConvert[base], vtype.dim)
return TypeTT(vtype.extra, vtype.dim)
# returns supers, exacts
def declTypeToActual(env, decltype):
name, newdim = baset(decltype), dim(decltype)
# Verifier treats bool[]s and byte[]s as interchangeable, so it could really be either
if newdim and (name == baset(ByteTT) or name == baset(BoolTT)):
return [], [withDimInc(ByteTT, newdim), withDimInc(BoolTT, newdim)]
elif not isBaseTClass(decltype): # primitive types can't be subclassed anyway
return [], [decltype]
# Verifier doesn't fully verify interfaces so they could be anything
if env.isInterface(name):
return [withDimInc(ObjectTT, newdim)], []
# If class is final, return it as exact, not super
elif env.isFinal(name):
return [], [decltype]
else:
return [decltype], []
def removeInterface(env, decltype):
name, newdim = baset(decltype), dim(decltype)
if isBaseTClass(decltype) and env.isInterface(name):
return withDimInc(ObjectTT, newdim)
return decltype

View File

@ -0,0 +1,10 @@
from .base import BaseJump
from .onexception import OnException
from .goto import Goto
from .ifcmp import If
from .exit import Return, Rethrow
from .switch import Switch
from . import placeholder
OnAbscond = Ret = placeholder.Placeholder

View File

@ -0,0 +1,18 @@
import copy
from ..functionbase import SSAFunctionBase
class BaseJump(SSAFunctionBase):
def __init__(self, parent, arguments=()):
super(BaseJump, self).__init__(parent,arguments)
def replaceBlocks(self, blockDict):
assert not self.getSuccessors()
def getNormalSuccessors(self): return []
def getExceptSuccessors(self): return []
def getSuccessors(self): return self.getNormalSuccessors() + self.getExceptSuccessors()
def getSuccessorPairs(self): return [(x,False) for x in self.getNormalSuccessors()] + [(x,True) for x in self.getExceptSuccessors()]
def reduceSuccessors(self, pairsToRemove): return self
def clone(self): return copy.copy(self) # overriden by classes which need to do a deep copy

View File

@ -0,0 +1,9 @@
from .base import BaseJump
class Return(BaseJump):
def __init__(self, parent, arguments):
super(Return, self).__init__(parent, arguments)
class Rethrow(BaseJump):
def __init__(self, parent, arguments):
super(Rethrow, self).__init__(parent, arguments)

View File

@ -0,0 +1,17 @@
from .base import BaseJump
class Goto(BaseJump):
def __init__(self, parent, target):
super(Goto, self).__init__(parent, [])
self.successors = [target]
def replaceBlocks(self, blockDict):
self.successors = [blockDict.get(key,key) for key in self.successors]
def getNormalSuccessors(self):
return self.successors
def reduceSuccessors(self, pairsToRemove):
if (self.successors[0], False) in pairsToRemove:
return None
return self

View File

@ -0,0 +1,101 @@
from .. import ssa_types
from ..constraints import IntConstraint, ObjectConstraint
from .base import BaseJump
from .goto import Goto
class If(BaseJump):
opposites = {'eq':'ne', 'ne':'eq', 'lt':'ge', 'ge':'lt', 'gt':'le', 'le':'gt'}
def __init__(self, parent, cmp, successors, arguments):
super(If, self).__init__(parent, arguments)
assert cmp in ('eq','ne','lt','ge','gt','le')
self.cmp = cmp
self.successors = successors
self.isObj = (arguments[0].type == ssa_types.SSA_OBJECT)
assert None not in successors
def replaceBlocks(self, blockDict):
self.successors = [blockDict.get(key,key) for key in self.successors]
def getNormalSuccessors(self):
return self.successors
def reduceSuccessors(self, pairsToRemove):
temp = set(self.successors)
for (child, t) in pairsToRemove:
temp.remove(child)
if len(temp) == 0:
return None
elif len(temp) == 1:
return Goto(self.parent, temp.pop())
return self
###############################################################################
def constrainJumps(self, x, y):
impossible = []
for child in self.successors:
func = self.getSuccessorConstraints((child,False))
results = func(x,y)
if None in results:
assert results == (None,None)
impossible.append((child,False))
return self.reduceSuccessors(impossible)
def getSuccessorConstraints(self, (block, t)):
assert t is False
cmp_t = If.opposites[self.cmp] if block == self.successors[0] else self.cmp
if self.isObj:
def propagateConstraints_obj(x, y):
if x is None or y is None:
return None, None
if cmp_t == 'eq':
z = x.join(y)
return z,z
else:
x2, y2 = x, y
if x.isConstNull():
yt = y.types
y2 = ObjectConstraint.fromTops(yt.env, yt.supers, yt.exact, nonnull=True)
if y.isConstNull():
xt = x.types
x2 = ObjectConstraint.fromTops(xt.env, xt.supers, xt.exact, nonnull=True)
return x2, y2
return propagateConstraints_obj
else:
def propagateConstraints_int(x, y):
if x is None or y is None:
return None, None
x1, x2, y1, y2 = x.min, x.max, y.min, y.max
if cmp_t == 'ge' or cmp_t == 'gt':
x1, x2, y1, y2 = y1, y2, x1, x2
# treat greater like less than swap before and afterwards
if cmp_t == 'lt' or cmp_t == 'gt':
x2 = min(x2, y2-1)
y1 = max(x1+1, y1)
elif cmp_t == 'le' or cmp_t == 'ge':
x2 = min(x2, y2)
y1 = max(x1, y1)
elif cmp_t == 'eq':
x1 = y1 = max(x1, y1)
x2 = y2 = min(x2, y2)
elif cmp_t == 'ne':
if x1 == x2 == y1 == y2:
return None, None
if x1 == x2:
y1 = y1 if y1 != x1 else y1+1
y2 = y2 if y2 != x2 else y2-1
if y1 == y2:
x1 = x1 if x1 != y1 else x1+1
x2 = x2 if x2 != y2 else x2-1
if cmp_t == 'ge' or cmp_t == 'gt':
x1, x2, y1, y2 = y1, y2, x1, x2
con1 = IntConstraint.range(x.width, x1, x2) if x1 <= x2 else None
con2 = IntConstraint.range(y.width, y1, y2) if y1 <= y2 else None
return con1, con2
return propagateConstraints_int

View File

@ -0,0 +1,59 @@
from .. import objtypes
from ..constraints import ObjectConstraint
from ..exceptionset import CatchSetManager, ExceptionSet
from .base import BaseJump
from .goto import Goto
class OnException(BaseJump):
def __init__(self, parent, throwvar, chpairs, fallthrough=None):
super(OnException, self).__init__(parent, [throwvar])
self.default = fallthrough
self.cs = CatchSetManager.new(parent.env, chpairs)
self.cs.pruneKeys()
def replaceExceptTarget(self, old, new):
self.cs.replaceKeys({old:new})
def replaceNormalTarget(self, old, new):
self.default = new if self.default == old else self.default
def replaceBlocks(self, blockDict):
self.cs.replaceKeys(blockDict)
if self.default is not None:
self.default = blockDict.get(self.default, self.default)
def reduceSuccessors(self, pairsToRemove):
for (child, t) in pairsToRemove:
if t:
self.cs.mask -= self.cs.sets[child]
del self.cs.sets[child]
else:
self.replaceNormalTarget(child, None)
self.cs.pruneKeys()
if not self.cs.sets:
if not self.default:
return None
return Goto(self.parent, self.default)
return self
def getNormalSuccessors(self):
return [self.default] if self.default is not None else []
def getExceptSuccessors(self):
return self.cs.sets.keys()
def clone(self):
new = super(OnException, self).clone()
new.cs = self.cs.copy()
return new
###############################################################################
def constrainJumps(self, x):
if x is None:
mask = ExceptionSet.EMPTY
else:
mask = ExceptionSet(x.types.env, [(objtypes.className(tt),()) for tt in x.types.supers | x.types.exact])
self.cs.newMask(mask)
return self.reduceSuccessors([])

View File

@ -0,0 +1,5 @@
from .base import BaseJump
class Placeholder(BaseJump):
def __init__(self, parent, *args, **kwargs):
super(Placeholder, self).__init__(parent)

View File

@ -0,0 +1,86 @@
import collections
from ..constraints import IntConstraint
from ..ssa_types import SSA_INT
from .base import BaseJump
from .goto import Goto
from .ifcmp import If
class Switch(BaseJump):
def __init__(self, parent, default, table, arguments):
super(Switch, self).__init__(parent, arguments)
# get ordered successors since our map will be unordered. Default is always first successor
if not table:
ordered = [default]
else:
tset = set()
ordered = [x for x in (default,) + zip(*table)[1] if not x in tset and not tset.add(x)]
self.successors = ordered
reverse = collections.defaultdict(set)
for k,v in table:
if v != default:
reverse[v].add(k)
self.reverse = {k: frozenset(v) for k, v in reverse.items()}
def getNormalSuccessors(self):
return self.successors
def replaceBlocks(self, blockDict):
self.successors = [blockDict.get(key,key) for key in self.successors]
self.reverse = {blockDict.get(k,k):v for k,v in self.reverse.items()}
def reduceSuccessors(self, pairsToRemove):
temp = list(self.successors)
for (child, t) in pairsToRemove:
temp.remove(child)
if len(temp) == 0:
return None
elif len(temp) == 1:
return Goto(self.parent, temp.pop())
if len(temp) < len(self.successors):
self.successors = temp
self.reverse = {v:self.reverse[v] for v in temp[1:]}
return self
def simplifyToIf(self, block):
# Try to replace with an if statement if possible
# e.g. switch(x) {case C: ... default: ...} -> if (x == C) {...} else {...}
if len(self.successors) == 2:
cases = self.reverse[self.successors[-1]]
if len(cases) == 1:
const = self.parent.makeVariable(SSA_INT)
const.const = min(cases)
block.unaryConstraints[const] = IntConstraint.const(32, const.const)
return If(self.parent, 'eq', self.successors, self.params + [const])
return self
###############################################################################
def constrainJumps(self, x):
impossible = []
for child in self.successors:
func = self.getSuccessorConstraints((child,False))
results = func(x)
if results[0] is None:
impossible.append((child,False))
return self.reduceSuccessors(impossible)
def getSuccessorConstraints(self, (block, t)):
if block in self.reverse:
cmin = min(self.reverse[block])
cmax = max(self.reverse[block])
def propagateConstraints(x):
if x is None:
return None,
return IntConstraint.range(x.width, max(cmin, x.min), min(cmax, x.max)),
else:
allcases = set().union(*self.reverse.values())
def propagateConstraints(x):
if x is None or (x.min == x.max and x.min in allcases):
return None,
return x,
return propagateConstraints

View File

@ -0,0 +1,16 @@
from .base import BaseOp
from .array import ArrLoad, ArrStore, ArrLength
from .checkcast import CheckCast, InstanceOf
from .convert import Convert
from .fieldaccess import FieldAccess
from .fmath import FAdd, FDiv, FMul, FRem, FSub, FNeg, FCmp
from .invoke import Invoke, InvokeDynamic
from .imath import IAdd, IDiv, IMul, IRem, ISub, IAnd, IOr, IShl, IShr, IUshr, IXor, ICmp
from .monitor import Monitor
from .new import New, NewArray, MultiNewArray
from .throw import Throw, MagicThrow
from .truncate import Truncate
from .tryreturn import TryReturn
from .phi import Phi, ExceptionPhi

View File

@ -0,0 +1,79 @@
from .. import excepttypes, objtypes
from ..constraints import FloatConstraint, IntConstraint, ObjectConstraint, maybeThrow, returnOrThrow, throw
from ..ssa_types import SSA_INT
from .base import BaseOp
def getElementTypes(env, tops):
types = [objtypes.withDimInc(tt, -1) for tt in tops]
# temporary hack
types = [objtypes.removeInterface(env, tt) for tt in types]
supers = [tt for tt in types if objtypes.isBaseTClass(tt)]
exact = [tt for tt in types if not objtypes.isBaseTClass(tt)]
return ObjectConstraint.fromTops(env, supers, exact)
class ArrLoad(BaseOp):
def __init__(self, parent, args, ssatype):
super(ArrLoad, self).__init__(parent, args, makeException=True)
self.env = parent.env
self.rval = parent.makeVariable(ssatype, origin=self)
self.ssatype = ssatype
def propagateConstraints(self, a, i):
etypes = (excepttypes.ArrayOOB,)
if a.null:
etypes += (excepttypes.NullPtr,)
if a.isConstNull():
return throw(ObjectConstraint.fromTops(self.env, [], [excepttypes.NullPtr], nonnull=True))
if self.ssatype[0] == 'int':
rout = IntConstraint.bot(self.ssatype[1])
elif self.ssatype[0] == 'float':
rout = FloatConstraint.bot(self.ssatype[1])
elif self.ssatype[0] == 'obj':
rout = getElementTypes(self.env, a.types.supers | a.types.exact)
eout = ObjectConstraint.fromTops(self.env, [], etypes, nonnull=True)
return returnOrThrow(rout, eout)
class ArrStore(BaseOp):
has_side_effects = True
def __init__(self, parent, args):
super(ArrStore, self).__init__(parent, args, makeException=True)
self.env = parent.env
def propagateConstraints(self, a, i, x):
etypes = (excepttypes.ArrayOOB,)
if a.null:
etypes += (excepttypes.NullPtr,)
if a.isConstNull():
return throw(ObjectConstraint.fromTops(self.env, [], [excepttypes.NullPtr], nonnull=True))
if isinstance(x, ObjectConstraint):
# If the type of a is known exactly to be the single possibility T[]
# and x is assignable to T, we can assume there is no ArrayStore exception
# if a's type has multiple possibilities, then there can be an exception
known_type = a.types.exact if len(a.types.exact) == 1 else frozenset()
allowed = getElementTypes(self.env, known_type)
if allowed.meet(x) != allowed:
etypes += (excepttypes.ArrayStore,)
return maybeThrow(ObjectConstraint.fromTops(self.env, [], etypes, nonnull=True))
class ArrLength(BaseOp):
def __init__(self, parent, args):
super(ArrLength, self).__init__(parent, args, makeException=True)
self.env = parent.env
self.rval = parent.makeVariable(SSA_INT, origin=self)
def propagateConstraints(self, x):
etypes = ()
if x.null:
etypes += (excepttypes.NullPtr,)
if x.isConstNull():
return throw(ObjectConstraint.fromTops(self.env, [], [excepttypes.NullPtr], nonnull=True))
excons = ObjectConstraint.fromTops(self.env, [], etypes, nonnull=True)
return returnOrThrow(IntConstraint.range(32, 0, (1<<31)-1), excons)

View File

@ -0,0 +1,30 @@
from ..functionbase import SSAFunctionBase
from ..ssa_types import SSA_OBJECT
class BaseOp(SSAFunctionBase):
has_side_effects = False
def __init__(self, parent, arguments, makeException=False):
super(BaseOp, self).__init__(parent, arguments)
self.rval = None
self.outException = None
if makeException:
self.outException = parent.makeVariable(SSA_OBJECT, origin=self)
def getOutputs(self):
return self.rval, self.outException
def removeOutput(self, var):
outs = self.rval, self.outException
assert var is not None and var in outs
self.rval, self.outException = [(x if x != var else None) for x in outs]
def replaceOutVars(self, vardict):
self.rval, self.outException = map(vardict.get, (self.rval, self.outException))
# Given input constraints, return constraints on outputs. Output is (rval, exception)
# With None returned for unused or impossible values. This should only be defined if it is
# actually implemented.
# def propagateConstraints(self, *cons):

View File

@ -0,0 +1,54 @@
import itertools
import operator
from ..constraints import IntConstraint
def split_pow2ranges(x,y):
'''split given range into power of two ranges of form [x, x+2^k)'''
out = []
while x<=y:
# The largest power of two range of the form x,k
# has k min of number of zeros at end of x
# and the largest power of two that fits in y-x
bx = bin(x)
numzeroes = float('inf') if x==0 else (len(bx)-bx.rindex('1')-1)
k = min(numzeroes, (y-x+1).bit_length()-1)
out.append((x,k))
x += 1<<k
assert x == y+1
return out
def propagateBitwise(arg1, arg2, op, usemin, usemax):
ranges1 = split_pow2ranges(arg1.min, arg1.max)
ranges2 = split_pow2ranges(arg2.min, arg2.max)
vals = []
for (s1,k1),(s2,k2) in itertools.product(ranges1, ranges2):
# there are three parts. The high bits fixed in both arguments,
# the middle bits fixed in one but not the other, and the
# lowest bits which can be chosen freely for both arguments
# high = op(h1,h2) and low goes from 0 to 1... but the range of
# the middle depends on the particular operation
# 0-x, x-1 and 0-1 for and, or, and xor respectively
if k1 > k2:
(s1,k1),(s2,k2) = (s2,k2),(s1,k1)
mask1 = (1<<k1) - 1
mask2 = (1<<k2) - 1 - mask1
high = op(s1, s2) & ~(mask1 | mask2)
midmin = (s1 & mask2) if usemin else 0
midmax = (s1 & mask2) if usemax else mask2
vals.append(high | midmin)
vals.append(high | midmax | mask1)
return IntConstraint.range(arg1.width, min(vals), max(vals))
def propagateAnd(x, y):
return propagateBitwise(x, y, operator.__and__, False, True)
def propagateOr(x, y):
return propagateBitwise(x, y, operator.__or__, True, False)
def propagateXor( x, y):
return propagateBitwise(x, y, operator.__xor__, False, False)

View File

@ -0,0 +1,40 @@
from .. import excepttypes, objtypes, ssa_types
from ..constraints import IntConstraint, ObjectConstraint, join, maybeThrow, return_, throw
from .base import BaseOp
class CheckCast(BaseOp):
def __init__(self, parent, target, args):
super(CheckCast, self).__init__(parent, args, makeException=True)
self.env = parent.env
self.target_tt = target
# Temporary hack
target = objtypes.removeInterface(self.env, target)
if objtypes.isBaseTClass(target):
self.outCasted = ObjectConstraint.fromTops(parent.env, [target], [])
else:
# Primative array types need to be in exact, not supers
self.outCasted = ObjectConstraint.fromTops(parent.env, [], [target])
self.outExceptionCons = ObjectConstraint.fromTops(parent.env, [], (excepttypes.ClassCast,), nonnull=True)
def propagateConstraints(self, x):
intersect = join(x, self.outCasted)
if intersect is None:
return throw(self.outExceptionCons)
elif intersect != x:
assert not x.isConstNull()
return maybeThrow(self.outExceptionCons)
else:
return return_(None)
class InstanceOf(BaseOp):
def __init__(self, parent, target, args):
super(InstanceOf, self).__init__(parent, args)
self.env = parent.env
self.target_tt = target
self.rval = parent.makeVariable(ssa_types.SSA_INT, origin=self)
def propagateConstraints(self, x):
rvalcons = IntConstraint.range(32, 0, 1)
return return_(rvalcons)

View File

@ -0,0 +1,8 @@
from .base import BaseOp
class Convert(BaseOp):
def __init__(self, parent, arg, source_ssa, target_ssa):
super(Convert, self).__init__(parent, [arg])
self.source = source_ssa
self.target = target_ssa
self.rval = parent.makeVariable(target_ssa, origin=self)

View File

@ -0,0 +1,62 @@
from ...verifier.descriptors import parseFieldDescriptor
from .. import constraints, excepttypes, objtypes
from ..constraints import IntConstraint, ObjectConstraint, returnOrThrow, throw
from ..ssa_types import SSA_INT, SSA_OBJECT, verifierToSSAType
from .base import BaseOp
# Empirically, Hotspot does enfore size restrictions on short fields
# Except that bool is still a byte
_short_constraints = {
objtypes.ByteTT: IntConstraint.range(32, -128, 127),
objtypes.CharTT: IntConstraint.range(32, 0, 65535),
objtypes.ShortTT: IntConstraint.range(32, -32768, 32767),
objtypes.IntTT: IntConstraint.bot(32)
}
_short_constraints[objtypes.BoolTT] = _short_constraints[objtypes.ByteTT]
# Assume no linkage errors occur, so only exception that can be thrown is NPE
class FieldAccess(BaseOp):
has_side_effects = True
def __init__(self, parent, instr, info, args):
super(FieldAccess, self).__init__(parent, args, makeException=('field' in instr[0]))
self.instruction = instr
self.target, self.name, self.desc = info
dtype = None
if 'get' in instr[0]:
vtypes = parseFieldDescriptor(self.desc)
stype = verifierToSSAType(vtypes[0])
dtype = objtypes.verifierToSynthetic(vtypes[0]) # todo, find way to merge this with Invoke code?
cat = len(vtypes)
self.rval = parent.makeVariable(stype, origin=self)
self.returned = [self.rval] + [None]*(cat-1)
else:
self.returned = []
# just use a fixed constraint until we can do interprocedural analysis
# output order is rval, exception, defined by BaseOp.getOutputs
env = parent.env
self.eout = ObjectConstraint.fromTops(env, [excepttypes.NullPtr], [], nonnull=True)
if self.rval is not None:
if self.rval.type == SSA_OBJECT:
supers, exact = objtypes.declTypeToActual(env, dtype)
self.rout = ObjectConstraint.fromTops(env, supers, exact)
elif self.rval.type == SSA_INT:
self.rout = _short_constraints[dtype]
else:
self.rout = constraints.fromVariable(env, self.rval)
else:
self.rout = None
def propagateConstraints(self, *incons):
eout = None # no NPE
if 'field' in self.instruction[0] and incons[0].null:
eout = self.eout
if incons[0].isConstNull():
return throw(eout)
return returnOrThrow(self.rout, eout)

View File

@ -0,0 +1,40 @@
from .. import ssa_types
from ..constraints import IntConstraint, return_
from .base import BaseOp
class FAdd(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
class FDiv(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
class FMul(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
class FRem(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
class FSub(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
# Unary, unlike the others
class FNeg(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
class FCmp(BaseOp):
def __init__(self, parent, args, NaN_val):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(ssa_types.SSA_INT, origin=self)
self.NaN_val = NaN_val
def propagateConstraints(self, x, y):
return return_(IntConstraint.range(32, -1, 1))

View File

@ -0,0 +1,208 @@
import itertools
from .. import excepttypes, ssa_types
from ..constraints import IntConstraint, ObjectConstraint, returnOrThrow, return_, throw
from . import bitwise_util
from .base import BaseOp
def getNewRange(w, zmin, zmax):
HN = 1 << w-1
zmin = zmin + HN
zmax = zmax + HN
split = (zmin>>w != zmax>>w)
if split:
return return_(IntConstraint.range(w, -HN, HN-1))
else:
N = 1<<w
return return_(IntConstraint.range(w, (zmin % N)-HN, (zmax % N)-HN))
class IAdd(BaseOp):
def __init__(self, parent, args):
super(IAdd, self).__init__(parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
return getNewRange(x.width, x.min+y.min, x.max+y.max)
class IMul(BaseOp):
def __init__(self, parent, args):
super(IMul, self).__init__(parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
vals = x.min*y.min, x.min*y.max, x.max*y.min, x.max*y.max
return getNewRange(x.width, min(vals), max(vals))
class ISub(BaseOp):
def __init__(self, parent, args):
super(ISub, self).__init__(parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
return getNewRange(x.width, x.min-y.max, x.max-y.min)
#############################################################################################
class IAnd(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
return return_(bitwise_util.propagateAnd(x,y))
class IOr(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
return return_(bitwise_util.propagateOr(x,y))
class IXor(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
return return_(bitwise_util.propagateXor(x,y))
#############################################################################################
# Shifts currently only propogate ranges in the case where the shift is a known constant
# TODO - make this handle the general case
def getMaskedRange(x, bits):
assert bits < x.width
y = IntConstraint.const(x.width, (1<<bits) - 1)
x = bitwise_util.propagateAnd(x,y)
H = 1<<(bits-1)
M = 1<<bits
m1 = x.min if (x.max <= H-1) else -H
m2 = x.max if (x.min >= -H) else H-1
return m1, m2
class IShl(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
if y.min < y.max:
return return_(IntConstraint.bot(x.width))
shift = y.min % x.width
if not shift:
return return_(x)
m1, m2 = getMaskedRange(x, x.width - shift)
return return_(IntConstraint.range(x.width, m1<<shift, m2<<shift))
class IShr(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
if y.min < y.max:
return return_(IntConstraint.range(x.width, min(x.min, 0), max(x.max, 0)))
shift = y.min % x.width
if not shift:
return return_(x)
m1, m2 = x.min, x.max
return return_(IntConstraint.range(x.width, m1>>shift, m2>>shift))
class IUshr(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(args[0].type, origin=self)
def propagateConstraints(self, x, y):
M = 1<<x.width
if y.min < y.max:
intmax = (M//2)-1
return return_(IntConstraint.range(x.width, min(x.min, 0), max(x.max, intmax)))
shift = y.min % x.width
if not shift:
return return_(x)
parts = [x.min, x.max]
if x.min <= -1 <= x.max:
parts.append(-1)
if x.min <= 0 <= x.max:
parts.append(0)
parts = [p % M for p in parts]
m1, m2 = min(parts), max(parts)
return return_(IntConstraint.range(x.width, m1>>shift, m2>>shift))
#############################################################################################
exec_tts = excepttypes.Arithmetic,
class IDiv(BaseOp):
def __init__(self, parent, args):
super(IDiv, self).__init__(parent, args, makeException=True)
self.rval = parent.makeVariable(args[0].type, origin=self)
self.outExceptionCons = ObjectConstraint.fromTops(parent.env, [], exec_tts, nonnull=True)
def propagateConstraints(self, x, y):
excons = self.outExceptionCons if (y.min <= 0 <= y.max) else None
if y.min == 0 == y.max:
return throw(excons)
# Calculate possible extremes for division, taking into account special case of intmin/-1
intmin = -1<<(x.width - 1)
xvals = set([x.min, x.max])
yvals = set([y.min, y.max])
for val in (intmin+1, 0):
if x.min <= val <= x.max:
xvals.add(val)
for val in (-2,-1,1):
if y.min <= val <= y.max:
yvals.add(val)
yvals.discard(0)
vals = set()
for xv, yv in itertools.product(xvals, yvals):
if xv == intmin and yv == -1:
vals.add(intmin)
elif xv*yv < 0: # Unlike Python, Java rounds to 0 so opposite sign case must be handled specially
vals.add(-(-xv//yv))
else:
vals.add(xv//yv)
rvalcons = IntConstraint.range(x.width, min(vals), max(vals))
return returnOrThrow(rvalcons, excons)
class IRem(BaseOp):
def __init__(self, parent, args):
super(IRem, self).__init__(parent, args, makeException=True)
self.rval = parent.makeVariable(args[0].type, origin=self)
self.outExceptionCons = ObjectConstraint.fromTops(parent.env, [], exec_tts, nonnull=True)
def propagateConstraints(self, x, y):
excons = self.outExceptionCons if (y.min <= 0 <= y.max) else None
if y.min == 0 == y.max:
return throw(excons)
# only do an exact result if both values are constants, and otherwise
# just approximate the range as -(y-1) to (y-1) (or 0 to y-1 if it's positive)
if x.min == x.max and y.min == y.max:
val = abs(x.min) % abs(y.min)
val = val if x.min >= 0 else -val
return return_(IntConstraint.range(x.width, val, val))
mag = max(abs(y.min), abs(y.max)) - 1
rmin = -min(mag, abs(x.min)) if x.min < 0 else 0
rmax = min(mag, abs(x.max)) if x.max > 0 else 0
rvalcons = IntConstraint.range(x.width, rmin, rmax)
return returnOrThrow(rvalcons, excons)
###############################################################################
class ICmp(BaseOp):
def __init__(self, parent, args):
BaseOp.__init__(self, parent, args)
self.rval = parent.makeVariable(ssa_types.SSA_INT, origin=self)
def propagateConstraints(self, x, y):
rvalcons = IntConstraint.range(32, -1, 1)
return return_(rvalcons)

View File

@ -0,0 +1,87 @@
from ...verifier.descriptors import parseMethodDescriptor
from .. import constraints, excepttypes, objtypes
from ..constraints import ObjectConstraint, returnOrThrow, throw
from ..ssa_types import SSA_OBJECT, verifierToSSAType
from .base import BaseOp
class Invoke(BaseOp):
has_side_effects = True
def __init__(self, parent, instr, info, args, isThisCtor, target_tt):
super(Invoke, self).__init__(parent, args, makeException=True)
self.instruction = instr
self.target, self.name, self.desc = info
self.isThisCtor = isThisCtor # whether this is a ctor call for the current class
self.target_tt = target_tt
vtypes = parseMethodDescriptor(self.desc)[1]
dtype = None
if vtypes:
stype = verifierToSSAType(vtypes[0])
dtype = objtypes.verifierToSynthetic(vtypes[0])
cat = len(vtypes)
# clone() on an array type is known to always return that type, rather than any Object
if self.name == "clone" and target_tt[1] > 0:
dtype = target_tt
self.rval = parent.makeVariable(stype, origin=self)
self.returned = [self.rval] + [None]*(cat-1)
else:
self.rval, self.returned = None, []
# just use a fixed constraint until we can do interprocedural analysis
# output order is rval, exception, defined by BaseOp.getOutputs
env = parent.env
self.eout = ObjectConstraint.fromTops(env, [objtypes.ThrowableTT], [], nonnull=True)
self.eout_npe = ObjectConstraint.fromTops(env, [excepttypes.NullPtr], [], nonnull=True)
if self.rval is not None:
if self.rval.type == SSA_OBJECT:
supers, exact = objtypes.declTypeToActual(env, dtype)
self.rout = ObjectConstraint.fromTops(env, supers, exact)
else:
self.rout = constraints.fromVariable(env, self.rval)
else:
self.rout = None
def propagateConstraints(self, *incons):
if self.instruction[0] != 'invokestatic' and incons[0].isConstNull():
return throw(self.eout_npe)
return returnOrThrow(self.rout, self.eout)
# TODO - cleanup
class InvokeDynamic(BaseOp):
has_side_effects = True
def __init__(self, parent, desc, args):
super(InvokeDynamic, self).__init__(parent, args, makeException=True)
self.desc = desc
vtypes = parseMethodDescriptor(self.desc)[1]
dtype = None
if vtypes:
stype = verifierToSSAType(vtypes[0])
dtype = objtypes.verifierToSynthetic(vtypes[0])
cat = len(vtypes)
self.rval = parent.makeVariable(stype, origin=self)
self.returned = [self.rval] + [None]*(cat-1)
else:
self.rval, self.returned = None, []
# just use a fixed constraint until we can do interprocedural analysis
# output order is rval, exception, defined by BaseOp.getOutputs
env = parent.env
self.eout = ObjectConstraint.fromTops(env, [objtypes.ThrowableTT], [], nonnull=True)
if self.rval is not None:
if self.rval.type == SSA_OBJECT:
supers, exact = objtypes.declTypeToActual(env, dtype)
self.rout = ObjectConstraint.fromTops(env, supers, exact)
else:
self.rout = constraints.fromVariable(env, self.rval)
else:
self.rout = None
def propagateConstraints(self, *incons):
return returnOrThrow(self.rout, self.eout)

View File

@ -0,0 +1,21 @@
from .. import excepttypes
from ..constraints import ObjectConstraint, maybeThrow
from .base import BaseOp
class Monitor(BaseOp):
has_side_effects = True
def __init__(self, parent, args, isExit):
BaseOp.__init__(self, parent, args, makeException=True)
self.exit = isExit
self.env = parent.env
def propagateConstraints(self, x):
etypes = ()
if x.null:
etypes += (excepttypes.NullPtr,)
if self.exit and not x.isConstNull():
etypes += (excepttypes.MonState,)
eout = ObjectConstraint.fromTops(self.env, [], etypes, nonnull=True)
return maybeThrow(eout)

View File

@ -0,0 +1,68 @@
from .. import excepttypes, objtypes
from ..constraints import ObjectConstraint, returnOrThrow, throw
from ..ssa_types import SSA_OBJECT
from .base import BaseOp
class New(BaseOp):
has_side_effects = True
def __init__(self, parent, name, inode_key):
super(New, self).__init__(parent, [], makeException=True)
self.env = parent.env
self.tt = objtypes.TypeTT(name, 0)
self.rval = parent.makeVariable(SSA_OBJECT, origin=self)
self.rval.uninit_orig_num = inode_key
def propagateConstraints(self):
eout = ObjectConstraint.fromTops(self.env, [], (excepttypes.OOM,), nonnull=True)
rout = ObjectConstraint.fromTops(self.env, [], [self.tt], nonnull=True)
return returnOrThrow(rout, eout)
class NewArray(BaseOp):
has_side_effects = True
def __init__(self, parent, param, baset):
super(NewArray, self).__init__(parent, [param], makeException=True)
self.baset = baset
self.rval = parent.makeVariable(SSA_OBJECT, origin=self)
self.tt = objtypes.withDimInc(baset, 1)
self.env = parent.env
def propagateConstraints(self, i):
if i.max < 0:
eout = ObjectConstraint.fromTops(self.env, [], (excepttypes.NegArrSize,), nonnull=True)
return throw(eout)
etypes = (excepttypes.OOM,)
if i.min < 0:
etypes += (excepttypes.NegArrSize,)
eout = ObjectConstraint.fromTops(self.env, [], etypes, nonnull=True)
rout = ObjectConstraint.fromTops(self.env, [], [self.tt], nonnull=True)
return returnOrThrow(rout, eout)
class MultiNewArray(BaseOp):
has_side_effects = True
def __init__(self, parent, params, type_):
super(MultiNewArray, self).__init__(parent, params, makeException=True)
self.tt = type_
self.rval = parent.makeVariable(SSA_OBJECT, origin=self)
self.env = parent.env
def propagateConstraints(self, *dims):
for i in dims:
if i.max < 0: # ignore possibility of OOM here
eout = ObjectConstraint.fromTops(self.env, [], (excepttypes.NegArrSize,), nonnull=True)
return throw(eout)
etypes = (excepttypes.OOM,)
for i in dims:
if i.min < 0:
etypes += (excepttypes.NegArrSize,)
break
eout = ObjectConstraint.fromTops(self.env, [], etypes, nonnull=True)
rout = ObjectConstraint.fromTops(self.env, [], [self.tt], nonnull=True)
return returnOrThrow(rout, eout)

View File

@ -0,0 +1,46 @@
from .base import BaseOp
class Phi(object):
__slots__ = 'block dict rval'.split()
has_side_effects = False
def __init__(self, block, rval):
self.block = block # used in constraint propagation
self.dict = {}
self.rval = rval
assert rval is not None and rval.origin is None
rval.origin = self
def add(self, key, val):
assert key not in self.dict
assert val.type == self.rval.type
assert val is not None
self.dict[key] = val
@property
def params(self): return [self.dict[k] for k in self.block.predecessors]
def get(self, key): return self.dict[key]
def delete(self, key): del self.dict[key]
# Copy these over from BaseOp so we don't need to inherit
def replaceVars(self, rdict):
for k in self.dict:
self.dict[k] = rdict.get(self.dict[k], self.dict[k])
def getOutputs(self):
return self.rval, None, None
def removeOutput(self, var):
assert var == self.rval
self.rval = None
# An extended basic block can contain multiple throwing instructions
# but the OnException jump expects a single param. The solution is
# to create a dummy op that behaves like a phi function, selecting
# among the possible thrown exceptions in the block. This is always
# the last op in block.lines when there are exceptions.
# As this is a phi, params can be variable length
class ExceptionPhi(BaseOp):
def __init__(self, parent, params):
BaseOp.__init__(self, parent, params, makeException=True)

View File

@ -0,0 +1,25 @@
from .. import excepttypes, objtypes
from ..constraints import ObjectConstraint, maybeThrow, throw
from .base import BaseOp
class Throw(BaseOp):
def __init__(self, parent, args):
super(Throw, self).__init__(parent, args, makeException=True)
self.env = parent.env
def propagateConstraints(self, x):
if x.null:
t = x.types
exact = list(t.exact) + [excepttypes.NullPtr]
return throw(ObjectConstraint.fromTops(t.env, t.supers, exact, nonnull=True))
return throw(x)
# Dummy instruction that can throw anything
class MagicThrow(BaseOp):
def __init__(self, parent):
super(MagicThrow, self).__init__(parent, [], makeException=True)
self.eout = ObjectConstraint.fromTops(parent.env, [objtypes.ThrowableTT], [], nonnull=True)
def propagateConstraints(self):
return maybeThrow(self.eout)

View File

@ -0,0 +1,37 @@
from ..constraints import IntConstraint, return_
from . import bitwise_util
from .base import BaseOp
class Truncate(BaseOp):
def __init__(self, parent, arg, signed, width):
super(Truncate, self).__init__(parent, [arg])
self.signed, self.width = signed, width
self.rval = parent.makeVariable(arg.type, origin=self)
def propagateConstraints(self, x):
# get range of target type
w = self.width
intw = x.width
assert w < intw
M = 1<<w
mask = IntConstraint.const(intw, M-1)
x = bitwise_util.propagateAnd(x,mask)
# We have the mods in the range [0,M-1], but we want it in the range
# [-M/2, M/2-1] so we need to find the new min and max
if self.signed:
HM = M>>1
parts = [(i-M if i>=HM else i) for i in (x.min, x.max)]
if x.min <= HM-1 <= x.max:
parts.append(HM-1)
if x.min <= HM <= x.max:
parts.append(-HM)
assert -HM <= min(parts) <= max(parts) <= HM-1
return return_(IntConstraint.range(intw, min(parts), max(parts)))
else:
return return_(x)

View File

@ -0,0 +1,12 @@
from .. import excepttypes
from ..constraints import ObjectConstraint, maybeThrow
from .base import BaseOp
class TryReturn(BaseOp):
def __init__(self, parent, canthrow=True):
super(TryReturn, self).__init__(parent, [], makeException=True)
self.outExceptionCons = ObjectConstraint.fromTops(parent.env, [], (excepttypes.MonState,), nonnull=True)
def propagateConstraints(self):
return maybeThrow(self.outExceptionCons)

View File

@ -0,0 +1,77 @@
from collections import namedtuple as nt
from .. import floatutil as fu
from ..verifier import verifier_types as vtypes
slots_t = nt('slots_t', ('locals', 'stack'))
def _localsAsList(self): return [t[1] for t in sorted(self.locals.items())]
slots_t.localsAsList = property(_localsAsList)
# types
SSA_INT = 'int', 32
SSA_LONG = 'int', 64
SSA_FLOAT = 'float', fu.FLOAT_SIZE
SSA_DOUBLE = 'float', fu.DOUBLE_SIZE
SSA_OBJECT = 'obj',
def verifierToSSAType(vtype):
vtype_dict = {vtypes.T_INT:SSA_INT,
vtypes.T_LONG:SSA_LONG,
vtypes.T_FLOAT:SSA_FLOAT,
vtypes.T_DOUBLE:SSA_DOUBLE}
# These should never be passed in here
assert vtype.tag not in ('.new','.init')
vtype = vtypes.withNoConst(vtype)
if vtypes.objOrArray(vtype):
return SSA_OBJECT
elif vtype in vtype_dict:
return vtype_dict[vtype]
return None
# Note: This is actually an Extended Basic Block. A normal basic block has to end whenever there is
# an instruction that can throw. This means that there is a seperate basic block for every throwing
# method, which causes horrible performance, especially in a large method with otherwise linear code.
# The solution is to use extended basic blocks, which are like normal basic blocks except that they
# can contain multiple throwing instructions as long as every throwing instruction has the same
# handlers. Due to the use of SSA, we also require that there are no changes to the locals between the
# first and last throwing instruction.
class BasicBlock(object):
__slots__ = "key phis lines jump unaryConstraints predecessors inslots throwvars chpairs except_used locals_at_except".split()
def __init__(self, key):
self.key = key
self.phis = None # The list of phi statements merging incoming variables
self.lines = [] # List of operations in the block
self.jump = None # The exit point (if, goto, etc)
# Holds constraints (range and type information) for each variable in the block.
# If the value is None, this variable cannot be reached
self.unaryConstraints = None
# List of predecessor pairs in deterministic order
self.predecessors = []
# temp vars used during graph creation
self.inslots = None
self.throwvars = []
self.chpairs = None
self.except_used = None
self.locals_at_except = None
def filterVarConstraints(self, keepvars):
self.unaryConstraints = {k:v for k,v in self.unaryConstraints.items() if k in keepvars}
def removePredPair(self, pair):
self.predecessors.remove(pair)
for phi in self.phis:
del phi.dict[pair]
def replacePredPair(self, oldp, newp):
self.predecessors[self.predecessors.index(oldp)] = newp
for phi in self.phis:
phi.dict[newp] = phi.dict[oldp]
del phi.dict[oldp]
def __str__(self): # pragma: no cover
return 'Block ' + str(self.key)
__repr__ = __str__

View File

@ -0,0 +1,59 @@
import copy
from .ssa_types import slots_t
class ProcInfo(object):
def __init__(self, retblock, target):
self.retblock = retblock
self.target = target
self.jsrblocks = []
assert target is retblock.jump.target
def __str__(self): # pragma: no cover
return 'Proc{}<{}>'.format(self.target.key, ', '.join(str(b.key) for b in self.jsrblocks))
__repr__ = __str__
###########################################################################################
class ProcJumpBase(object):
@property
def params(self):
return [v for v in self.input.stack + self.input.localsAsList if v is not None]
# [v for v in self.input.stack if v] + [v for k, v in sorted(self.input.locals.items()) if v]
def replaceBlocks(self, blockDict):
self.target = blockDict.get(self.target, self.target)
def getExceptSuccessors(self): return ()
def getSuccessors(self): return self.getNormalSuccessors()
def getSuccessorPairs(self): return [(x,False) for x in self.getNormalSuccessors()]
def reduceSuccessors(self, pairsToRemove): return self
class ProcCallOp(ProcJumpBase):
def __init__(self, target, fallthrough, inslots, outslots):
self.fallthrough = fallthrough
self.target = target
self.input = inslots
self.output = outslots
for var in self.output.stack + self.output.locals.values():
if var is not None:
assert var.origin is None
var.origin = self
# def flatOutput(self): return [v for v in self.output.stack if v] + [v for k, v in sorted(self.output.locals.items()) if v]
def flatOutput(self): return self.output.stack + self.output.localsAsList
def getNormalSuccessors(self): return self.fallthrough, self.target
class DummyRet(ProcJumpBase):
def __init__(self, inslots, target):
self.target = target
self.input = inslots
def replaceVars(self, varDict):
newstack = [varDict.get(v, v) for v in self.input.stack]
newlocals = {k: varDict.get(v, v) for k, v in self.input.locals.items()}
self.input = slots_t(stack=newstack, locals=newlocals)
def getNormalSuccessors(self): return ()
def clone(self): return copy.copy(self) # target and input will be replaced later by calls to replaceBlocks/Vars

View File

@ -0,0 +1,7 @@
def thunk(initial):
stack = [initial]
while stack:
try:
stack.append(next(stack[-1]))
except StopIteration:
stack.pop()

View File

@ -0,0 +1,85 @@
from .verifier_types import T_ARRAY, T_BOOL, T_BYTE, T_CHAR, T_DOUBLE, T_FLOAT, T_INT, T_INVALID, T_LONG, T_OBJECT, T_SHORT, unSynthesizeType
_cat2tops = T_LONG, T_DOUBLE
def parseFieldDescriptors(desc_str, unsynthesize=True):
baseTypes = {'B':T_BYTE, 'C':T_CHAR, 'D':T_DOUBLE, 'F':T_FLOAT,
'I':T_INT, 'J':T_LONG, 'S':T_SHORT, 'Z':T_BOOL}
fields = []
while desc_str:
oldlen = len(desc_str)
desc_str = desc_str.lstrip('[')
dim = oldlen - len(desc_str)
if dim > 255:
raise ValueError('Dimension {} > 255 in descriptor'.format(dim))
if not desc_str:
raise ValueError('Descriptor contains [s at end of string')
if desc_str[0] == 'L':
end = desc_str.find(';')
if end == -1:
raise ValueError('Unmatched L in descriptor')
name = desc_str[1:end]
desc_str = desc_str[end+1:]
baset = T_OBJECT(name)
else:
if desc_str[0] not in baseTypes:
raise ValueError('Unrecognized code {} in descriptor'.format(desc_str[0]))
baset = baseTypes[desc_str[0]]
desc_str = desc_str[1:]
if dim:
# Hotspot considers byte[] and bool[] identical for type checking purposes
if unsynthesize and baset == T_BOOL:
baset = T_BYTE
baset = T_ARRAY(baset, dim)
elif unsynthesize:
# synthetics are only meaningful as basetype of an array
# if they are by themselves, convert to int.
baset = unSynthesizeType(baset)
fields.append(baset)
if baset in _cat2tops:
fields.append(T_INVALID)
return fields
# get a single descriptor
def parseFieldDescriptor(desc_str, unsynthesize=True):
rval = parseFieldDescriptors(desc_str, unsynthesize)
cat = 2 if (rval and rval[0] in _cat2tops) else 1
if len(rval) != cat:
raise ValueError('Incorrect number of fields in descriptor, expected {} but found {}'.format(cat, len(rval)))
return rval
# Parse a string to get a Java Method Descriptor
def parseMethodDescriptor(desc_str, unsynthesize=True):
if not desc_str.startswith('('):
raise ValueError('Method descriptor does not start with (')
# we need to split apart the argument list and return value
# this is greatly complicated by the fact that ) is a legal
# character that can appear in class names
lp_pos = desc_str.rfind(')') # this case will work if return type is not an object
if desc_str.endswith(';'):
lbound = max(desc_str.rfind(';', 1, -1), 1)
lp_pos = desc_str.find(')', lbound, -1)
if lp_pos < 0 or desc_str[lp_pos] != ')':
raise ValueError('Unable to split method descriptor into arguments and return type')
arg_str = desc_str[1:lp_pos]
rval_str = desc_str[lp_pos+1:]
args = parseFieldDescriptors(arg_str, unsynthesize)
rval = [] if rval_str == 'V' else parseFieldDescriptor(rval_str, unsynthesize)
return args, rval
# Adds self argument for nonstatic. Constructors must be handled seperately
def parseUnboundMethodDescriptor(desc_str, target, isstatic):
args, rval = parseMethodDescriptor(desc_str)
if not isstatic:
args = [T_OBJECT(target)] + args
return args, rval

View File

@ -0,0 +1,508 @@
import itertools
from .. import bytecode, error as error_types, opnames as ops
from .descriptors import parseFieldDescriptor, parseMethodDescriptor, parseUnboundMethodDescriptor
from .verifier_types import OBJECT_INFO, T_ADDRESS, T_ARRAY, T_DOUBLE, T_FLOAT, T_INT, T_INT_CONST, T_INVALID, T_LONG, T_NULL, T_OBJECT, T_UNINIT_OBJECT, T_UNINIT_THIS, decrementDim, exactArrayFrom, fullinfo_t, mergeTypes
class VerifierTypesState(object):
def __init__(self, stack, locals, masks):
self.stack = stack
self.locals = locals
self.masks = masks
def copy(self): return VerifierTypesState(self.stack, self.locals, self.masks)
def withExcept(self, t): return VerifierTypesState([t], self.locals, self.masks)
def pop(self, n):
if n == 0:
return []
self.stack, popped = self.stack[:-n], self.stack[-n:]
return popped
def push(self, vals):
self.stack = self.stack + list(vals)
def setLocal(self, i, v):
if len(self.locals) < i:
self.locals = self.locals + [T_INVALID]*(i - len(self.locals))
self.locals = self.locals[:i] + [v] + self.locals[i+1:]
new = frozenset([i])
self.masks = [(addr, old | new) for addr, old in self.masks]
def local(self, i):
if len(self.locals) <= i:
return T_INVALID
return self.locals[i]
def jsr(self, target):
self.masks = self.masks + [(target, frozenset())]
def replace(self, old, new):
self.stack = [(new if v == old else v) for v in self.stack]
mask = frozenset(i for i, v in enumerate(self.locals) if v == old)
self.locals = [(new if v == old else v) for v in self.locals]
self.masks = [(addr, oldmask | mask) for addr, oldmask in self.masks]
def invalidateNews(self):
# Doesn't need to update mask
self.stack = [(T_INVALID if v.tag == '.new' else v) for v in self.stack]
self.locals = [(T_INVALID if v.tag == '.new' else v) for v in self.locals]
def maskFor(self, called):
self.masks = self.masks[:]
target, mask = self.masks.pop()
while target != called:
target, mask = self.masks.pop()
return mask
def returnTo(self, called, jsrstate):
mask = self.maskFor(called)
# merge locals using mask
zipped = itertools.izip_longest(self.locals, jsrstate.locals, fillvalue=T_INVALID)
self.locals = [(x if i in mask else y) for i,(x,y) in enumerate(zipped)]
def merge(self, other, env):
old_triple = self.stack, self.locals, self.masks
assert len(self.stack) == len(other.stack)
self.stack = [mergeTypes(env, new, old) for old, new in zip(self.stack, other.stack)]
self.locals = [mergeTypes(env, new, old) for old, new in zip(self.locals, other.locals)]
while self.locals and self.locals[-1] == T_INVALID:
self.locals.pop()
# Merge Masks
last_match = -1
mergedmasks = []
for entry1, mask1 in self.masks:
for j, (entry2, mask2) in enumerate(other.masks):
if j > last_match and entry1 == entry2:
item = entry1, (mask1 | mask2)
mergedmasks.append(item)
last_match = j
self.masks = mergedmasks
return (self.stack, self.locals, self.masks) != old_triple
def stateFromInitialArgs(args): return VerifierTypesState([], args[:], [])
_invoke_ops = (ops.INVOKESPECIAL, ops.INVOKESTATIC, ops.INVOKEVIRTUAL, ops.INVOKEINTERFACE, ops.INVOKEINIT, ops.INVOKEDYNAMIC)
def _loadFieldDesc(cpool, ind):
target, name, desc = cpool.getArgsCheck('Field', ind)
return parseFieldDescriptor(desc)
def _loadMethodDesc(cpool, ind):
target, name, desc = cpool.getArgs(ind)
return parseMethodDescriptor(desc)
def _indexToCFMInfo(cpool, ind, typen):
actual = cpool.getType(ind)
# JVM_GetCPMethodClassNameUTF accepts both
assert actual == typen or actual == 'InterfaceMethod' and typen == 'Method'
cname = cpool.getArgs(ind)[0]
if cname.startswith('[') or cname.endswith(';'):
try:
return parseFieldDescriptor(cname)[0]
except ValueError as e:
return T_INVALID
else:
return T_OBJECT(cname)
# Instructions which pop a fixed amount
_popAmount = {
ops.ARRLOAD_OBJ: 2,
ops.ARRSTORE_OBJ: 3,
ops.ARRLOAD: 2,
ops.TRUNCATE: 1,
ops.LCMP: 4,
ops.IF_A: 1,
ops.IF_I: 1,
ops.IF_ACMP: 2,
ops.IF_ICMP: 2,
ops.SWITCH: 1,
ops.NEWARRAY: 1,
ops.ANEWARRAY: 1,
ops.ARRLEN: 1,
ops.THROW: 1,
ops.CHECKCAST: 1,
ops.INSTANCEOF: 1,
ops.MONENTER: 1,
ops.MONEXIT: 1,
ops.GETFIELD: 1,
ops.NOP: 0,
ops.CONSTNULL: 0,
ops.CONST: 0,
ops.LDC: 0,
ops.LOAD: 0,
ops.IINC: 0,
ops.GOTO: 0,
ops.JSR: 0,
ops.RET: 0,
ops.NEW: 0,
ops.GETSTATIC: 0,
}
# Instructions which pop a variable amount depending on whether type is category 2
_popAmountVar = {
ops.STORE: (1, 0),
ops.NEG: (1, 0),
ops.CONVERT: (1, 0),
ops.ADD: (2, 0),
ops.SUB: (2, 0),
ops.MUL: (2, 0),
ops.DIV: (2, 0),
ops.REM: (2, 0),
ops.XOR: (2, 0),
ops.OR: (2, 0),
ops.AND: (2, 0),
ops.FCMP: (2, 0),
ops.SHL: (1, 1),
ops.SHR: (1, 1),
ops.USHR: (1, 1),
ops.ARRSTORE: (1, 2),
}
# Generic stack codes
genericStackCodes = {
ops.POP: (1, []),
ops.POP2: (2, []),
ops.DUP: (1, [0, 0]),
ops.DUPX1: (2, [1, 0, 1]),
ops.DUPX2: (3, [2, 0, 1, 2]),
ops.DUP2: (2, [0, 1, 0, 1]),
ops.DUP2X1: (3, [1, 2, 0, 1, 2]),
ops.DUP2X2: (4, [2, 3, 0, 1, 2, 3]),
ops.SWAP: (2, [1, 0]),
}
def _getPopAmount(cpool, instr, method):
op = instr[0]
if op in _popAmount:
return _popAmount[op]
if op in _popAmountVar:
a, b = _popAmountVar[op]
cat = 2 if instr[1] in 'JD' else 1
return a * cat + b
if op in genericStackCodes:
return genericStackCodes[op][0]
if op == ops.MULTINEWARRAY:
return instr[2]
elif op == ops.RETURN:
return len(parseMethodDescriptor(method.descriptor)[1])
elif op in (ops.PUTFIELD, ops.PUTSTATIC):
args = len(_loadFieldDesc(cpool, instr[1]))
if op == ops.PUTFIELD:
args += 1
return args
elif op in _invoke_ops:
args = len(_loadMethodDesc(cpool, instr[1])[0])
if op != ops.INVOKESTATIC and op != ops.INVOKEDYNAMIC:
args += 1
return args
codes = dict(zip('IFJD', [T_INT, T_FLOAT, T_LONG, T_DOUBLE]))
def _getStackResult(cpool, instr, key):
op = instr[0]
if op in (ops.TRUNCATE, ops.LCMP, ops.FCMP, ops.ARRLEN, ops.INSTANCEOF):
return T_INT
elif op in (ops.ADD, ops.SUB, ops.MUL, ops.DIV, ops.REM, ops.XOR, ops.AND, ops.OR, ops.SHL, ops.SHR, ops.USHR, ops.NEG):
return codes[instr[1]]
elif op == ops.CONSTNULL:
return T_NULL
elif op == ops.CONST:
if instr[1] == 'I':
return T_INT_CONST(instr[2])
return codes[instr[1]]
elif op == ops.ARRLOAD:
return codes.get(instr[1], T_INT)
elif op == ops.CONVERT:
return codes[instr[2]]
elif op == ops.LDC:
return {
'Int': T_INT,
'Long': T_LONG,
'Float': T_FLOAT,
'Double': T_DOUBLE,
'String': T_OBJECT('java/lang/String'),
'Class': T_OBJECT('java/lang/Class'),
'MethodType': T_OBJECT('java/lang/invoke/MethodType'),
'MethodHandle': T_OBJECT('java/lang/invoke/MethodHandle'),
}[cpool.getType(instr[1])]
elif op == ops.JSR:
return T_ADDRESS(instr[1])
elif op in (ops.CHECKCAST, ops.NEW, ops.ANEWARRAY, ops.MULTINEWARRAY):
target = _indexToCFMInfo(cpool, instr[1], 'Class')
if op == ops.ANEWARRAY:
return T_ARRAY(target)
elif op == ops.NEW:
return T_UNINIT_OBJECT(key)
return target
elif op == ops.NEWARRAY:
return parseFieldDescriptor('[' + instr[1])[0]
elif op in (ops.GETFIELD, ops.GETSTATIC):
return _loadFieldDesc(cpool, instr[1])[0]
elif op in _invoke_ops:
out = _loadMethodDesc(cpool, instr[1])[1]
assert 0 <= len(out) <= 2
return out[0] if out else None
class InstructionNode(object):
__slots__ = "key code env class_ cpool instruction op visited changed offsetToIndex indexToOffset state out_state jsrTarget next_instruction returnedFrom successors pop_amount stack_push stack_code target_type isThisCtor".split()
def __init__(self, code, offsetToIndex, indexToOffset, key):
self.key = key
assert(self.key is not None) # if it is this will cause problems with origin tracking
self.code = code
self.env = code.class_.env
self.class_ = code.class_
self.cpool = self.class_.cpool
self.instruction = code.bytecode[key]
self.op = self.instruction[0]
self.visited, self.changed = False, False
# store for usage calculating JSRs, finding successor instructions and the like
self.offsetToIndex = offsetToIndex
self.indexToOffset = indexToOffset
self.state = None
# Fields to be assigned later
self.jsrTarget = None
self.next_instruction = None
self.returnedFrom = None
self.successors = None
self.pop_amount = -1
self.stack_push = []
self.stack_code = None
# for blockmaker
self.target_type = None
self.isThisCtor = False
self.out_state = None # store out state for JSR/RET instructions
self._precomputeValues()
def _removeInterface(self, vt):
if vt.tag == '.obj' and vt.extra is not None and self.env.isInterface(vt.extra, forceCheck=True):
return T_ARRAY(OBJECT_INFO, vt.dim)
return vt
def _precomputeValues(self):
# parsed_desc, successors
off_i = self.offsetToIndex[self.key]
self.next_instruction = self.indexToOffset[off_i+1] # None if end of code
op = self.instruction[0]
self.pop_amount = _getPopAmount(self.cpool, self.instruction, self.code.method)
# cache these, since they're not state dependent
result = _getStackResult(self.cpool, self.instruction, self.key)
# temporary hack
if op == ops.CHECKCAST:
result = self._removeInterface(result)
if result is not None:
self.stack_push = [result]
if result in (T_LONG, T_DOUBLE):
self.stack_push.append(T_INVALID)
if op in genericStackCodes:
self.stack_code = genericStackCodes[op][1]
if op == ops.NEW:
self.target_type = _indexToCFMInfo(self.cpool, self.instruction[1], 'Class')
# Now get successors
next_ = self.next_instruction
if op in (ops.IF_A, ops.IF_I, ops.IF_ICMP, ops.IF_ACMP):
self.successors = next_, self.instruction[2]
elif op in (ops.JSR, ops.GOTO):
self.successors = self.instruction[1],
elif op in (ops.RETURN, ops.THROW):
self.successors = ()
elif op == ops.RET:
self.successors = None # calculate it when the node is reached
elif op == ops.SWITCH:
opname, default, jumps = self.instruction
targets = (default,)
if jumps:
targets += zip(*jumps)[1]
self.successors = targets
else:
self.successors = next_,
def _getNewState(self, iNodes):
state = self.state.copy()
popped = state.pop(self.pop_amount)
# Local updates/reading
op = self.instruction[0]
if op == ops.LOAD:
state.push([state.local(self.instruction[2])])
if self.instruction[1] in 'JD':
state.push([T_INVALID])
elif op == ops.STORE:
for i, val in enumerate(popped):
state.setLocal(self.instruction[2] + i, val)
elif op == ops.IINC:
state.setLocal(self.instruction[1], T_INT) # Make sure to clobber constants
elif op == ops.JSR:
state.jsr(self.instruction[1])
elif op == ops.NEW:
# This should never happen, but better safe than sorry.
state.replace(self.stack_push[0], T_INVALID)
elif op == ops.INVOKEINIT:
old = popped[0]
if old.tag == '.new':
new = _indexToCFMInfo(self.cpool, self.instruction[1], 'Method')
else: # .init
new = T_OBJECT(self.class_.name)
self.isThisCtor = True
state.replace(old, new)
# Make sure that push happens after local replacement in case of new/invokeinit
if self.stack_code is not None:
state.push(popped[i] for i in self.stack_code)
elif op == ops.ARRLOAD_OBJ:
# temporary hack
result = self._removeInterface(decrementDim(popped[0]))
state.push([result])
elif op == ops.NEWARRAY or op == ops.ANEWARRAY:
arrt = self.stack_push[0]
size = popped[0].const
if size is not None:
arrt = exactArrayFrom(arrt, size)
state.push([arrt])
else:
state.push(self.stack_push)
if self.op in (ops.RET, ops.JSR):
state.invalidateNews()
self.out_state = state # store for later convienence
assert all(isinstance(vt, fullinfo_t) for vt in state.stack)
assert all(isinstance(vt, fullinfo_t) for vt in state.locals)
return state
def _mergeSingleSuccessor(self, other, newstate, iNodes, isException):
if self.op == ops.RET and not isException:
# Get the instruction before other
off_i = self.offsetToIndex[other.key]
jsrnode = iNodes[self.indexToOffset[off_i-1]]
jsrnode.returnedFrom = self.key
if jsrnode.visited: # if not, skip for later
newstate = newstate.copy()
newstate.returnTo(jsrnode.instruction[1], jsrnode.state)
else:
return
if not other.visited:
other.state = newstate.copy()
other.visited = other.changed = True
else:
changed = other.state.merge(newstate, self.env)
other.changed = other.changed or changed
def update(self, iNodes, exceptions):
assert self.visited
self.changed = False
newstate = self._getNewState(iNodes)
successors = self.successors
if self.op == ops.JSR and self.returnedFrom is not None:
iNodes[self.returnedFrom].changed = True
if successors is None:
assert self.op == ops.RET
called = self.state.local(self.instruction[1]).extra
temp = [n.next_instruction for n in iNodes.values() if (n.op == ops.JSR and n.instruction[1] == called)]
successors = self.successors = tuple(temp)
self.jsrTarget = called # store for later use in ssa creation
# Merge into exception handlers first
for (start, end, handler, except_info) in exceptions:
if start <= self.key < end:
self._mergeSingleSuccessor(handler, self.state.withExcept(except_info), iNodes, True)
if self.op == ops.INVOKEINIT: # two cases since the ctor may suceed or fail before throwing
self._mergeSingleSuccessor(handler, newstate.withExcept(except_info), iNodes, True)
# Now regular successors
for k in self.successors:
self._mergeSingleSuccessor(iNodes[k], newstate, iNodes, False)
def __str__(self): # pragma: no cover
lines = ['{}: {}'.format(self.key, bytecode.printInstruction(self.instruction))]
if self.visited:
lines.append('Stack: ' + ', '.join(map(str, self.state.stack)))
lines.append('Locals: ' + ', '.join(map(str, self.state.locals)))
if self.state.masks:
lines.append('Masks:')
lines += ['\t{}: {}'.format(entry, sorted(cset)) for entry, cset in self.state.masks]
else:
lines.append('\tunvisited')
return '\n'.join(lines) + '\n'
def verifyBytecode(code):
method, class_ = code.method, code.class_
args, rval = parseUnboundMethodDescriptor(method.descriptor, class_.name, method.static)
env = class_.env
# Object has no superclass to construct, so it doesn't get an uninit this
if method.isConstructor and class_.name != 'java/lang/Object':
assert args[0] == T_OBJECT(class_.name)
args[0] = T_UNINIT_THIS
assert len(args) <= 255 and len(args) <= code.locals
offsets = sorted(code.bytecode.keys())
offset_rmap = {v:i for i,v in enumerate(offsets)}
offsets.append(None) # Sentinel for end of code
iNodes = [InstructionNode(code, offset_rmap, offsets, key) for key in offsets[:-1]]
iNodeLookup = {n.key:n for n in iNodes}
keys = frozenset(iNodeLookup)
for raw in code.except_raw:
if not ((0 <= raw.start < raw.end) and (raw.start in keys) and
(raw.handler in keys) and (raw.end in keys or raw.end == code.codelen)):
keylist = sorted(keys) + [code.codelen]
msg = "Illegal exception handler: {}\nValid offsets are: {}".format(raw, ', '.join(map(str, keylist)))
raise error_types.VerificationError(msg)
def makeException(rawdata):
if rawdata.type_ind:
typen = class_.cpool.getArgsCheck('Class', rawdata.type_ind)
else:
typen = 'java/lang/Throwable'
return (rawdata.start, rawdata.end, iNodeLookup[rawdata.handler], T_OBJECT(typen))
exceptions = map(makeException, code.except_raw)
start = iNodes[0]
start.state = stateFromInitialArgs(args)
start.visited, start.changed = True, True
done = False
while not done:
done = True
for node in iNodes:
if node.changed:
node.update(iNodeLookup, exceptions)
done = False
return iNodes

View File

@ -0,0 +1,141 @@
from collections import namedtuple as nt
# Define types for Inference
# Extra stores address for .new and .address and class name for object types
# Const stores value for int constants and length for exact arrays. This isn't needed for normal verification but is
# useful for optimizing the code later.
fullinfo_t = nt('fullinfo_t', ['tag','dim','extra','const'])
valid_tags = ['.'+_x for _x in 'int float double long obj new init address byte short char boolean'.split()]
valid_tags = frozenset([None] + valid_tags)
def _makeinfo(tag, dim=0, extra=None, const=None):
assert tag in valid_tags
return fullinfo_t(tag, dim, extra, const)
T_INVALID = _makeinfo(None)
T_INT = _makeinfo('.int')
T_FLOAT = _makeinfo('.float')
T_DOUBLE = _makeinfo('.double')
T_LONG = _makeinfo('.long')
T_NULL = _makeinfo('.obj')
T_UNINIT_THIS = _makeinfo('.init')
T_BYTE = _makeinfo('.byte')
T_SHORT = _makeinfo('.short')
T_CHAR = _makeinfo('.char')
T_BOOL = _makeinfo('.boolean') # Hotspot doesn't have a bool type, but we can use this elsewhere
# types with arguments
def T_ADDRESS(entry):
return _makeinfo('.address', extra=entry)
def T_OBJECT(name):
return _makeinfo('.obj', extra=name)
def T_ARRAY(baset, newDimensions=1):
assert 0 <= baset.dim <= 255-newDimensions
return _makeinfo(baset.tag, baset.dim+newDimensions, baset.extra)
def T_UNINIT_OBJECT(origin):
return _makeinfo('.new', extra=origin)
def T_INT_CONST(val):
assert -0x80000000 <= val < 0x80000000
return _makeinfo(T_INT.tag, const=val)
OBJECT_INFO = T_OBJECT('java/lang/Object')
CLONE_INFO = T_OBJECT('java/lang/Cloneable')
SERIAL_INFO = T_OBJECT('java/io/Serializable')
THROWABLE_INFO = T_OBJECT('java/lang/Throwable')
def objOrArray(fi): # False on uninitialized
return fi.tag == '.obj' or fi.dim > 0
def unSynthesizeType(t):
if t in (T_BOOL, T_BYTE, T_CHAR, T_SHORT):
return T_INT
return t
def decrementDim(fi):
if fi == T_NULL:
return T_NULL
assert fi.dim
tag = unSynthesizeType(fi).tag if fi.dim <= 1 else fi.tag
return _makeinfo(tag, fi.dim-1, fi.extra)
def exactArrayFrom(fi, size):
assert fi.dim > 0
if size >= 0:
return _makeinfo(fi.tag, fi.dim, fi.extra, size)
return fi
def withNoDimension(fi):
return _makeinfo(fi.tag, 0, fi.extra)
def withNoConst(fi):
if fi.const is None:
return fi
return _makeinfo(fi.tag, fi.dim, fi.extra)
def _decToObjArray(fi):
return fi if fi.tag == '.obj' else T_ARRAY(OBJECT_INFO, fi.dim-1)
def mergeTypes(env, t1, t2):
if t1 == t2:
return t1
t1 = withNoConst(t1)
t2 = withNoConst(t2)
if t1 == t2:
return t1
# non objects must match exactly
if not objOrArray(t1) or not objOrArray(t2):
return T_INVALID
if t1 == T_NULL:
return t2
elif t2 == T_NULL:
return t1
if t1 == OBJECT_INFO or t2 == OBJECT_INFO:
return OBJECT_INFO
if t1.dim or t2.dim:
for x in (t1,t2):
if x in (CLONE_INFO,SERIAL_INFO):
return x
t1 = _decToObjArray(t1)
t2 = _decToObjArray(t2)
if t1.dim > t2.dim:
t1, t2 = t2, t1
if t1.dim == t2.dim:
res = mergeTypes(env, withNoDimension(t1), withNoDimension(t2))
return res if res == T_INVALID else _makeinfo('.obj', t1.dim, res.extra)
else: # t1.dim < t2.dim
return t1 if withNoDimension(t1) in (CLONE_INFO, SERIAL_INFO) else T_ARRAY(OBJECT_INFO, t1.dim)
else: # neither is array
if env.isInterface(t2.extra, forceCheck=True):
return OBJECT_INFO
return T_OBJECT(env.commonSuperclass(t1.extra, t2.extra))
# Make verifier types printable for easy debugging
def vt_toStr(self): # pragma: no cover
if self == T_INVALID:
return '.none'
elif self == T_NULL:
return '.null'
if self.tag == '.obj':
base = self.extra
elif self.extra is not None:
base = '{}<{}>'.format(self.tag, self.extra)
else:
base = self.tag
return base + '[]'*self.dim
fullinfo_t.__str__ = fullinfo_t.__repr__ = vt_toStr

View File

@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<http://www.gnu.org/philosophy/why-not-lgpl.html>.
<http://www.gnu.org/philosophy/why-not-lgpl.html>.

Some files were not shown because too many files have changed in this diff Show More