View Javadoc

1   // x86Assembler.java, created Mon Feb  5 23:23:19 2001 by joewhaley
2   // Copyright (C) 2001-3 John Whaley <jwhaley@alum.mit.edu>
3   // Licensed under the terms of the GNU LGPL; see COPYING for details.
4   package joeq.Assembler.x86;
5   
6   import java.util.HashMap;
7   import java.util.Iterator;
8   import java.util.Map;
9   import joeq.Allocator.DefaultCodeAllocator;
10  import joeq.Allocator.CodeAllocator.x86CodeBuffer;
11  import joeq.Main.jq;
12  import joeq.Memory.CodeAddress;
13  import jwutil.collections.LightRelation;
14  import jwutil.collections.Relation;
15  import jwutil.strings.Strings;
16  import jwutil.util.Assert;
17  
18  // Referenced classes of package joeq.Assembler.x86:
19  //            x86Constants, x86CodeBuffer, x86
20  
21  /***
22   * x86Assembler
23   *
24   * @author John Whaley <jwhaley@alum.mit.edu>
25   * @version $Id: x86Assembler.java 1941 2004-09-30 03:37:06Z joewhaley $
26   */
public class x86Assembler implements x86Constants {

    /**
     * Records a spot in the code buffer whose operand bytes must be
     * rewritten ("patched") once the branch target is known.
     * patchLocation is the code offset just PAST the operand field, so the
     * field itself occupies bytes [patchLocation-patchSize, patchLocation).
     * This base class patches pc-relative displacements; see AbsPatchInfo
     * for absolute (address-valued) patches.
     */
    static class PatchInfo {

        // patchLocation: code offset immediately after the operand field.
        // patchSize: width of the operand field in bytes (1 or 4 supported).
        int patchLocation, patchSize;
        
        PatchInfo(int patchLocation, int patchSize) {
            this.patchLocation = patchLocation;
            this.patchSize = patchSize;
        }
        
        /**
         * Rewrite the operand field as a pc-relative displacement to the
         * given target code offset.  x86 relative branches are relative to
         * the end of the instruction, which is exactly patchLocation here,
         * hence the displacement is (target - patchLocation).
         *
         * @param mc      code buffer to patch
         * @param target  code offset of the branch target
         */
        void patchTo(x86CodeBuffer mc, int target) {
            if (patchSize == 4) {
                int v = mc.get4_endian(patchLocation - 4);
                // The forward-branch emitters write recognizable placeholder
                // patterns (0x44444444 for calls, 0x55555555 for jumps,
                // 0x66666666 for conditional jumps -- see emitCALL_Forw,
                // emitJUMP_Forw, emitCJUMP_Forw below; 0x77777777 is
                // presumably written by an absolute-reference emitter outside
                // this file).  Asserting the placeholder is still present
                // guards against double-patching or patching the wrong spot.
                Assert._assert(v == 0x44444444 || v == 0x55555555 || v == 0x66666666 || v == 0x77777777, "Location: "+Strings.hex(patchLocation-4)+" value: "+Strings.hex8(v));
                mc.put4_endian(patchLocation - 4, target - patchLocation);
            } else if (patchSize == 1) {
                byte v = mc.get1(patchLocation - 1);
                // Short-branch placeholders are emitted as 0 (see
                // emitCJUMP_Forw_Short / emitJUMP_Forw_Short below).
                Assert._assert(v == 0);
                // An 8-bit displacement must fit in a signed byte.
                Assert._assert(target - patchLocation <= 127);
                Assert._assert(target - patchLocation >= -128);
                mc.put1(patchLocation - 1, (byte)(target - patchLocation));
            } else
                Assert.TODO();
        }
        
        public String toString() {
            return "loc:"+Strings.hex(patchLocation)+" size:"+patchSize;
        }

    }

    /**
     * A patch whose operand field receives the ABSOLUTE 32-bit address of
     * the target (code start address + target offset), rather than a
     * pc-relative displacement.  Used for absolute references recorded via
     * recordAbsoluteReference().
     */
    static class AbsPatchInfo extends PatchInfo {

        AbsPatchInfo(int patchLocation, int patchSize) {
            super(patchLocation, patchSize);
        }
        
        /**
         * Rewrite the 4-byte operand field with the absolute 32-bit address
         * of code offset {@code target}.  Only 4-byte patches are supported.
         */
        void patchTo(x86CodeBuffer mc, int target) {
            if (patchSize == 4) {
                int v = mc.get4_endian(patchLocation - 4);
                // Same placeholder sanity check as PatchInfo.patchTo.
                Assert._assert(v == 0x44444444 || v == 0x55555555 || v == 0x66666666 || v == 0x77777777, "Location: "+Strings.hex(patchLocation-4)+" value: "+Strings.hex8(v));
                mc.put4_endian(patchLocation - 4, mc.getStartAddress().offset(target).to32BitValue());
            } else
                Assert.TODO();
        }
        
        public String toString() {
            return "loc:"+Strings.hex(patchLocation)+" size:"+patchSize+" (abs)";
        }

    }
    
    /**
     * Returns the underlying code buffer.  Warns (on stdout, without
     * failing) if any forward branches were recorded but never resolved --
     * such branches still contain placeholder bytes.
     */
    public x86CodeBuffer getCodeBuffer() {
        if (!branches_to_patch.isEmpty())
            System.out.println("Error: unresolved forward branches!");
        return mc;
    }
    /** Current code offset, delegated to the code buffer. */
    public int getCurrentOffset() { return mc.getCurrentOffset(); }
    /** Current absolute code address, delegated to the code buffer. */
    public CodeAddress getCurrentAddress() { return mc.getCurrentAddress(); }
    /** Start address of the code buffer. */
    public CodeAddress getStartAddress() { return mc.getStartAddress(); }
    /** Overwrite one byte at the given code offset. */
    public void patch1(int offset, byte value) { mc.put1(offset, value); }
    /** Overwrite four bytes (little-endian) at the given code offset. */
    public void patch4_endian(int offset, int value) { mc.put4_endian(offset, value); }

    /**
     * Creates an assembler backed by a freshly allocated code buffer.
     *
     * @param num_targets  hint for the number of branch targets
     *                     (currently unused -- TODO confirm intent)
     * @param est_size     estimated code size, passed to the allocator
     * @param offset       starting offset for the code buffer
     * @param alignment    required alignment for the code buffer
     */
    public x86Assembler(int num_targets, int est_size, int offset, int alignment) {
        mc = DefaultCodeAllocator.getCodeBuffer(est_size, offset, alignment);
        if (TRACE) System.out.println("Assembler start address: "+mc.getCurrentAddress().stringRep());
        branchtargetmap = new HashMap();
        branches_to_patch = new LightRelation();
    }

    /** Returns true if a code offset has been recorded for the given target. */
    public boolean containsTarget(Object target) {
        return branchtargetmap.containsKey(target);
    }
    // backward branches
    /**
     * Records the current code offset as the location of {@code target},
     * so later backward branches can be emitted directly against it.
     */
    public void recordBranchTarget(Object target) {
        // ip is our shadow copy of the buffer's offset; they must agree.
        Assert._assert(ip == mc.getCurrentOffset());
        branchtargetmap.put(target, new Integer(ip));
    }
    /**
     * Looks up the recorded code offset of {@code target}.
     * Fails (UNREACHABLE) if the target was never recorded.
     */
    public int getBranchTarget(Object target) {
        Integer i = (Integer)branchtargetmap.get(target);
        if (i == null) {
            Assert.UNREACHABLE("Invalid branch target: "+target+" offset "+getCurrentOffset());
        }
        return i.intValue();
    }
    /** Exposes the target -&gt; offset map (live, not a copy). */
    public Map getBranchTargetMap() {
        return branchtargetmap;
    }

    // forward branches
    /**
     * Remembers that the instruction just emitted ends with a
     * {@code patchsize}-byte pc-relative displacement that must be patched
     * when {@code target} is eventually resolved.  Must be called
     * immediately after emitting the instruction (it captures ip).
     */
    public void recordForwardBranch(int patchsize, Object target) {
        if (TRACE) System.out.println("recording forward branch from "+Strings.hex(ip)+" (size "+patchsize+") to "+target);
        branches_to_patch.add(target, new PatchInfo(ip, patchsize));
    }
    /**
     * Like recordForwardBranch, but the field receives the target's
     * absolute address instead of a pc-relative displacement.
     */
    public void recordAbsoluteReference(int patchsize, Object target) {
        if (TRACE) System.out.println("recording absolute reference from "+Strings.hex(ip)+" (size "+patchsize+") to "+target);
        branches_to_patch.add(target, new AbsPatchInfo(ip, patchsize));
    }
    /**
     * Resolves {@code target} to the CURRENT code offset: patches every
     * pending forward branch/reference to it, then forgets them.
     * Call this at the point in the code where the target lives.
     */
    public void resolveForwardBranches(Object target) {
        PatchInfo p;
        Iterator it = branches_to_patch.getValues(target).iterator();
        while (it.hasNext()) {
            p = (PatchInfo)it.next();
            if (TRACE) System.out.println("patching branch to "+target+" ("+p+") to point to "+Strings.hex(ip));
            p.patchTo(mc, ip);
        }
        branches_to_patch.removeKey(target);
    }

    // dynamic patch section
    /**
     * Begins a region of {@code size} bytes that may be patched at runtime.
     * On SMP, first pads with NOPs -- NOTE(review): the padding condition
     * compares (ip &amp; mask) against (end &amp; mask), i.e. offsets WITHIN a
     * cache line rather than cache-line numbers; if the intent is to keep
     * the region from straddling a cache line, (ip &amp; ~mask) would be
     * expected -- confirm against callers before relying on this.
     */
    public void startDynamicPatch(int size) {
        if (jq.SMP) {
            int end = ip+size;
            int mask = CACHE_LINE_SIZE-1;
            while ((ip & mask) != (end & mask))
                emit1(x86.NOP);
        }
        dynPatchStart = ip;
        dynPatchSize = size;
    }
    /**
     * Ends the current dynamic-patch region: asserts the emitted code did
     * not overflow it, and NOP-pads it out to exactly the reserved size so
     * a runtime patch can overwrite a fixed-length span.
     */
    public void endDynamicPatch() {
        Assert._assert(ip <= dynPatchStart + dynPatchSize);
        while (ip < dynPatchStart + dynPatchSize) 
            emit1(x86.NOP);
        dynPatchSize = 0;
    }

    // prefix
    /** Emits a one-byte instruction prefix (e.g. an operand-size override). */
    public void emitprefix(byte prefix) {
        mc.add1(prefix);
        ++ip;
    }

    // special case instructions
    /**
     * Emits PUSH imm, choosing the short imm8 form when the value fits in
     * 7 unsigned bits (see fits()), else the 5-byte imm32 form.
     */
    public void emitPUSH_i(int imm) {
        if (fits(imm, 8))
            ip += x86.PUSH_i8.emit1_Imm8(mc, imm);
        else
            ip += x86.PUSH_i32.emit1_Imm32(mc, imm);
    }
    /**
     * Emits a shift instruction on a memory operand [base+off] with an
     * immediate count.  Picks the "shift by 1" opcode form when imm == 1,
     * and the shortest addressing mode: ESP as base forces a SIB byte;
     * disp-less EA is used when off == 0 and base != EBP (the EBP encoding
     * without displacement means something else); disp8 when the offset
     * fits in a signed byte; disp32 otherwise.
     */
    public void emit2_SHIFT_Mem_Imm8(x86 x, int off, int base, byte imm) {
        if (base == ESP) {
            if (off == 0) {
                if (imm == 1)
                    ip += x.emit2_Once_SIB_EA(mc, ESP, ESP, SCALE_1);
                else
                    ip += x.emit2_SIB_EA_Imm8(mc, ESP, ESP, SCALE_1, imm);
            } else if (fits_signed(off, 8)) {
                if (imm == 1)
                    ip += x.emit2_Once_SIB_DISP8(mc, ESP, ESP, SCALE_1, (byte)off);
                else
                    ip += x.emit2_SIB_DISP8_Imm8(mc, ESP, ESP, SCALE_1, (byte)off, imm);
            } else {
                if (imm == 1)
                    ip += x.emit2_Once_SIB_DISP32(mc, ESP, ESP, SCALE_1, off);
                else
                    ip += x.emit2_SIB_DISP32_Imm8(mc, ESP, ESP, SCALE_1, off, imm);
            }
        } else if (off == 0 && base != EBP) {
            if (imm == 1)
                ip += x.emit2_Once_EA(mc, base);
            else
                ip += x.emit2_EA_Imm8(mc, base, imm);
        } else if (fits_signed(off, 8)) {
            if (imm == 1)
                ip += x.emit2_Once_DISP8(mc, (byte)off, base);
            else
                ip += x.emit2_DISP8_Imm8(mc, (byte)off, base, imm);
        } else {
            if (imm == 1)
                ip += x.emit2_Once_DISP32(mc, off, base);
            else
                ip += x.emit2_DISP32_Imm8(mc, off, base, imm);
        }
    }

    /**
     * Emits a shift of register r1 by an immediate count; uses the
     * dedicated "shift by 1" form when imm == 1.
     */
    public void emit2_SHIFT_Reg_Imm8(x86 x, int r1, byte imm) {
        if (imm == 1)
            ip += x.emit2_Once_Reg(mc, r1);
        else
            ip += x.emit2_Reg_Imm8(mc, r1, imm);
    }

    // swap the order, because it is confusing.
    /** Emits SHLD r1, r2, CL (operand order swapped here for clarity). */
    public void emitSHLD_r_r_rc(int r1, int r2) {
        ip += x86.SHLD_r_r_rc.emit3_Reg_Reg(mc, r2, r1);
    }

    // swap the order, because it is confusing.
    /** Emits SHRD r1, r2, CL (operand order swapped here for clarity). */
    public void emitSHRD_r_r_rc(int r1, int r2) {
        ip += x86.SHRD_r_r_rc.emit3_Reg_Reg(mc, r2, r1);
    }

    // short
    // "Short" forms encode the register in the opcode byte itself
    // (e.g. PUSH r32 / MOV r32,imm32).
    public void emitShort_Reg(x86 x, int r1) {
        ip += x.emitShort_Reg(mc, r1);
    }
    public void emitShort_Reg_Imm(x86 x, int r1, int imm) {
        ip += x.emitShort_Reg_Imm32(mc, r1, imm);
    }

    // length 1
    // One-byte-opcode instructions; each emit* returns the number of bytes
    // written, which is accumulated into ip.
    public void emit1(x86 x) {
        ip += x.emit1(mc);
    }
    public void emit1_Imm8(x86 x, byte imm) {
        ip += x.emit1_Imm8(mc, imm);
    }
    public void emit1_Imm16(x86 x, char imm) {
        ip += x.emit1_Imm16(mc, imm);
    }
    public void emit1_Imm32(x86 x, int imm) {
        ip += x.emit1_Imm32(mc, imm);
    }

    // length 2
    // Two-byte-opcode instructions and one-byte opcodes with a ModR/M byte.
    public void emit2(x86 x) {
        ip += x.emit2(mc);
    }
    public void emit2_FPReg(x86 x, int r) {
        ip += x.emit2_FPReg(mc, r);
    }
    /** Memory operand given as an absolute 32-bit address. */
    public void emit2_Mem(x86 x, int imm) {
        ip += x.emit2_Abs32(mc, imm);
    }
    /**
     * Memory operand [base+off].  Addressing-mode selection mirrors
     * emit2_SHIFT_Mem_Imm8: SIB for ESP base, disp-less EA when off == 0
     * and base != EBP, then disp8/disp32 by offset size.
     */
    public void emit2_Mem(x86 x, int off, int base) {
        if (base == ESP) {
            if (off == 0)
                ip += x.emit2_SIB_EA(mc, ESP, ESP, SCALE_1);
            else if (fits_signed(off, 8))
                ip += x.emit2_SIB_DISP8(mc, ESP, ESP, SCALE_1, (byte)off);
            else
                ip += x.emit2_SIB_DISP32(mc, ESP, ESP, SCALE_1, off);
        } else if (off == 0 && base != EBP)
            ip += x.emit2_EA(mc, base);
        else if (fits_signed(off, 8))
            ip += x.emit2_DISP8(mc, (byte)off, base);
        else
            ip += x.emit2_DISP32(mc, off, base);
    }
    /**
     * Memory operand [base + ind*scale + off] via a SIB byte.
     * ESP cannot be an index (no encoding) nor, in this helper, a base.
     */
    public void emit2_Mem(x86 x, int base, int ind, int scale, int off) {
        Assert._assert(ind != ESP);
        Assert._assert(base != ESP);
        if (off == 0)
            ip += x.emit2_SIB_EA(mc, base, ind, scale);
        else if (fits_signed(off, 8))
            ip += x.emit2_SIB_DISP8(mc, base, ind, scale, (byte)off);
        else
            ip += x.emit2_SIB_DISP32(mc, base, ind, scale, off);
    }
    /** Memory operand [base+off] with a 32-bit immediate. */
    public void emit2_Mem_Imm(x86 x, int off, int base, int imm) {
        if (base == ESP) {
            if (off == 0)
                ip += x.emit2_SIB_EA_Imm32(mc, ESP, ESP, SCALE_1, imm);
            else if (fits_signed(off, 8))
                ip += x.emit2_SIB_DISP8_Imm32(mc, ESP, ESP, SCALE_1, (byte)off, imm);
            else
                ip += x.emit2_SIB_DISP32_Imm32(mc, ESP, ESP, SCALE_1, off, imm);
        } else if (off == 0 && base != EBP)
            ip += x.emit2_EA_Imm32(mc, base, imm);
        else if (fits_signed(off, 8))
            ip += x.emit2_DISP8_Imm32(mc, (byte)off, base, imm);
        else
            ip += x.emit2_DISP32_Imm32(mc, off, base, imm);
    }
    public void emit2_Reg(x86 x, int r1) {
        ip += x.emit2_Reg(mc, r1);
    }
    /** Register + memory operand at an absolute 32-bit address. */
    public void emit2_Reg_Mem(x86 x, int r1, int addr) {
        ip += x.emit2_Reg_Abs32(mc, r1, addr);
    }
    /** Register + memory operand [base+off]; same mode selection as above. */
    public void emit2_Reg_Mem(x86 x, int r1, int off, int base) {
        if (base == ESP) {
            if (off == 0)
                ip += x.emit2_Reg_SIB_EA(mc, r1, ESP, ESP, SCALE_1);
            else if (fits_signed(off, 8))
                ip += x.emit2_Reg_SIB_DISP8(mc, r1, ESP, ESP, SCALE_1, (byte)off);
            else
                ip += x.emit2_Reg_SIB_DISP32(mc, r1, ESP, ESP, SCALE_1, off);
        } else if (off == 0 && base != EBP)
            ip += x.emit2_Reg_EA(mc, r1, base);
        else if (fits_signed(off, 8))
            ip += x.emit2_Reg_DISP8(mc, r1, (byte)off, base);
        else
            ip += x.emit2_Reg_DISP32(mc, r1, off, base);
    }
    /** Register + scaled-index memory operand [base + ind*scale + off]. */
    public void emit2_Reg_Mem(x86 x, int r1, int base, int ind, int scale, int off) {
        if (off == 0)
            ip += x.emit2_Reg_SIB_EA(mc, r1, base, ind, scale);
        else if (fits_signed(off, 8))
            ip += x.emit2_Reg_SIB_DISP8(mc, r1, base, ind, scale, (byte)off);
        else
            ip += x.emit2_Reg_SIB_DISP32(mc, r1, base, ind, scale, off);
    }
    public void emit2_Reg_Reg(x86 x, int r1, int r2) {
        ip += x.emit2_Reg_Reg(mc, r1, r2);
    }

    // length 3
    // Three-byte (0x0F-prefixed etc.) instruction forms.
    public void emit3_Reg_Reg(x86 x, int r1, int r2) {
        ip += x.emit3_Reg_Reg(mc, r1, r2);
    }
    public void emit3_Reg_Mem(x86 x, int r1, int addr) {
        ip += x.emit3_Reg_Abs32(mc, r1, addr);
    }
    public void emit3_Reg_Mem(x86 x, int r1, int off, int base) {
        if (base == ESP) {
            if (off == 0)
                ip += x.emit3_Reg_SIB_EA(mc, r1, ESP, ESP, SCALE_1);
            else if (fits_signed(off, 8))
                ip += x.emit3_Reg_SIB_DISP8(mc, r1, ESP, ESP, SCALE_1, (byte)off);
            else
                ip += x.emit3_Reg_SIB_DISP32(mc, r1, ESP, ESP, SCALE_1, off);
        } else if (off == 0 && base != EBP)
            ip += x.emit3_Reg_EA(mc, r1, base);
        else if (fits_signed(off, 8))
            ip += x.emit3_Reg_DISP8(mc, r1, (byte)off, base);
        else
            ip += x.emit3_Reg_DISP32(mc, r1, off, base);
    }
    public void emit3_Reg_Mem(x86 x, int r1, int base, int ind, int mult, int off) {
        if (off == 0)
            ip += x.emit3_Reg_SIB_EA(mc, r1, base, ind, mult);
        else if (fits_signed(off, 8))
            ip += x.emit3_Reg_SIB_DISP8(mc, r1, base, ind, mult, (byte)off);
        else
            ip += x.emit3_Reg_SIB_DISP32(mc, r1, base, ind, mult, off);
    }
    
    // arithmetic (with special EAX, Imm forms and 8-bit sign-extended immediates)
    /**
     * Emits an arithmetic op with a memory destination [base+off] and an
     * immediate.  Uses the sign-extended imm8 form when the immediate fits
     * in a signed byte -- except for TEST, which has no imm8 form.
     * Addressing-mode selection as in emit2_Mem; note "base != 5" here is
     * presumably the EBP encoding (cf. "base != EBP" in emit2_Mem --
     * confirm against x86Constants).
     */
    public void emitARITH_Mem_Imm(x86 x, int off, int base, int imm) {
        if (base == ESP) {
            if (off == 0) {
                if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
                    ip += x.emit2_SIB_EA_SEImm8(mc, ESP, ESP, SCALE_1, (byte)imm);
                else
                    ip += x.emit2_SIB_EA_Imm32(mc, ESP, ESP, SCALE_1, imm);
            } else if (fits_signed(off, 8)) {
                if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
                    ip += x.emit2_SIB_DISP8_SEImm8(mc, ESP, ESP, SCALE_1, (byte)off, (byte)imm);
                else
                    ip += x.emit2_SIB_DISP8_Imm32(mc, ESP, ESP, SCALE_1, (byte)off, imm);
            } else {
                if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
                    ip += x.emit2_SIB_DISP32_SEImm8(mc, ESP, ESP, SCALE_1, off, (byte)imm);
                else
                    ip += x.emit2_SIB_DISP32_Imm32(mc, ESP, ESP, SCALE_1, off, imm);
            }
        } else if (off == 0 && base != 5) {
            if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
                ip += x.emit2_EA_SEImm8(mc, base, (byte)imm);
            else
                ip += x.emit2_EA_Imm32(mc, base, imm);
        } else if (fits_signed(off, 8)) {
            if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
                ip += x.emit2_DISP8_SEImm8(mc, (byte)off, base, (byte)imm);
            else
                ip += x.emit2_DISP8_Imm32(mc, (byte)off, base, imm);
        } else {
            if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
                ip += x.emit2_DISP32_SEImm8(mc, off, base, (byte)imm);
            else
                ip += x.emit2_DISP32_Imm32(mc, off, base, imm);
        }
    }
    /**
     * Emits an arithmetic op with a register destination and an immediate;
     * sign-extended imm8 form when possible (never for TEST).
     */
    public void emitARITH_Reg_Imm(x86 x, int r1, int imm) {
        //if (r1 == EAX)
        //    ip += x.emit1_RA_Imm32(mc, imm);
        //else
        if (x != x86.TEST_r_i32 && fits_signed(imm, 8))
            ip += x.emit2_Reg_SEImm8(mc, r1, (byte)imm);
        else
            ip += x.emit2_Reg_Imm32(mc, r1, imm);
    }
    public void emitARITH_Reg_Reg(x86 x, int r1, int r2) {
        ip += x.emit2_Reg_Reg(mc, r1, r2);
    }
    public void emitARITH_Reg_Mem(x86 x, int r1, int off, int base) {
        if (base == ESP) {
            if (off == 0)
                ip += x.emit2_Reg_SIB_EA(mc, r1, ESP, ESP, SCALE_1);
            else if (fits_signed(off, 8))
                ip += x.emit2_Reg_SIB_DISP8(mc, r1, ESP, ESP, SCALE_1, (byte)off);
            else
                ip += x.emit2_Reg_SIB_DISP32(mc, r1, ESP, ESP, SCALE_1, off);
        } else if (off == 0 && base != 5)
            ip += x.emit2_Reg_EA(mc, r1, base);
        else if (fits_signed(off, 8))
            ip += x.emit2_Reg_DISP8(mc, r1, (byte)off, base);
        else
            ip += x.emit2_Reg_DISP32(mc, r1, off, base);
    }

    // conditional jumps
    /**
     * Emits a conditional jump to an already-recorded (backward) target.
     * The short Jcc is 2 bytes, so the displacement is computed relative
     * to ip+2; the near form's instruction is 4 bytes longer, hence the
     * extra "- 4" when the short form's range is exceeded.  Only the
     * lower bound is checked because a backward branch's offset is
     * non-positive.
     */
    public void emitCJUMP_Back(x86 x, Object target) {
        Assert._assert(x.length == 1);
        int offset = getBranchTarget(target) - ip - 2;
        if (offset >= -128) {
            if (TRACE) System.out.println("Short cjump back from offset "+Strings.hex(ip+2)+" to "+target+" offset "+getBranchTarget(target)+" (relative offset "+Strings.shex(offset)+")");
            ip += x.emitCJump_Short(mc, (byte)offset);
        } else {
            if (TRACE) System.out.println("Near cjump back from offset "+Strings.hex(ip+6)+" to "+target+" offset "+getBranchTarget(target)+" (relative offset "+Strings.shex(offset-4)+")");
            ip += x.emitCJump_Near(mc, offset - 4);
        }
    }
    /** Emits a short conditional jump with an explicit rel8 displacement. */
    public void emitCJUMP_Short(x86 x, byte offset) {
        Assert._assert(x.length == 1);
        ip += x.emitCJump_Short(mc, offset);
    }
    /**
     * Emits a short conditional jump to a not-yet-resolved target; the
     * rel8 field is written as the 0 placeholder and recorded for patching.
     */
    public void emitCJUMP_Forw_Short(x86 x, Object target) {
        Assert._assert(x.length == 1);
        ip += x.emitCJump_Short(mc, (byte)0);
        recordForwardBranch(1, target);
    }
    /**
     * Emits a near conditional jump to a not-yet-resolved target; the
     * rel32 field gets the 0x66666666 placeholder checked by PatchInfo.
     */
    public void emitCJUMP_Forw(x86 x, Object target) {
        Assert._assert(x.length == 1);
        ip += x.emitCJump_Near(mc, 0x66666666);
        recordForwardBranch(4, target);
    }

    // unconditional jumps
    /**
     * Emits an unconditional jump to an already-recorded (backward)
     * target.  Short JMP is 2 bytes; near JMP is 3 bytes longer, hence
     * the "- 3" adjustment for the near form.
     */
    public void emitJUMP_Back(x86 x, Object target) {
        Assert._assert(x.length == 1);
        int offset = getBranchTarget(target) - ip - 2;
        if(offset >= -128) {
            if (TRACE) System.out.println("Short jump back from offset "+Strings.hex(ip+2)+" to "+target+" offset "+getBranchTarget(target)+" (relative offset "+Strings.shex(offset)+")");
            ip += x.emitJump_Short(mc, (byte)offset);
        } else {
            if (TRACE) System.out.println("Near jump back from offset "+Strings.hex(ip+5)+" to "+target+" offset "+getBranchTarget(target)+" (relative offset "+Strings.shex(offset-3)+")");
            ip += x.emitJump_Near(mc, offset - 3);
        }
    }
    /** Emits a short jump with an explicit rel8 displacement. */
    public void emitJUMP_Short(x86 x, byte offset) {
        Assert._assert(x.length == 1);
        ip += x.emitJump_Short(mc, offset);
    }
    /** Short jump to an unresolved target (0 placeholder, patched later). */
    public void emitJUMP_Forw_Short(x86 x, Object target) {
        Assert._assert(x.length == 1);
        ip += x.emitJump_Short(mc, (byte)0);
        recordForwardBranch(1, target);
    }
    /** Near jump to an unresolved target (0x55555555 placeholder). */
    public void emitJUMP_Forw(x86 x, Object target) {
        Assert._assert(x.length == 1);
        ip += x.emitJump_Near(mc, 0x55555555);
        recordForwardBranch(4, target);
    }

    // relative calls
    /** Emits a near call with an explicit rel32 displacement. */
    public void emitCALL_rel32(x86 x, int address) {
        Assert._assert(x.length == 1);
        ip += x.emitCall_Near(mc, address);
    }
    /**
     * Emits a near call to an already-recorded target; near CALL is
     * 5 bytes, so the displacement is relative to ip+5.
     */
    public void emitCALL_Back(x86 x, Object target) {
        Assert._assert(x.length == 1);
        int offset = getBranchTarget(target) - ip - 5;
        ip += x.emitCall_Near(mc, offset);
    }
    /** Near call to an unresolved target (0x44444444 placeholder). */
    public void emitCALL_Forw(x86 x, Object target) {
        Assert._assert(x.length == 1);
        ip += x.emitCall_Near(mc, 0x44444444);
        recordForwardBranch(4, target);
    }
    
    /** Emits four raw little-endian data bytes into the code stream. */
    public void emitDATA(int data) {
        mc.add4_endian(data);
        ip += 4;
    }

    /** Advances the buffer by nbytes without writing (reserves space). */
    public void skip(int nbytes) {
        if (TRACE) System.out.println("skipping "+nbytes+" bytes");
        mc.skip(nbytes);
        ip += nbytes;
    }
    
    /** Marks the current buffer position as the code's entry point. */
    public void setEntrypoint() {
        mc.setEntrypoint();
    }
    
    /**
     * True iff val fits in {@code bits} bits as a NON-NEGATIVE value,
     * i.e. val is in [0, 2^(bits-1) - 1]; any negative val returns false.
     * (For bits == 8 this is the [0,127] range, so negative PUSH
     * immediates always take the imm32 form -- see emitPUSH_i.)
     */
    public static boolean fits(int val, int bits) {
        val >>= bits - 1;
        return val == 0;
    }

    /**
     * True iff val fits in {@code bits} bits as a signed two's-complement
     * value, i.e. val is in [-2^(bits-1), 2^(bits-1) - 1].
     */
    public static boolean fits_signed(int val, int bits) {
        val >>= bits - 1;
        return val == 0 || val == -1;
    }

    // When true, every emit logs what it wrote; intended to be final in
    // production builds (see the commented-out modifier).
    public static /*final*/ boolean TRACE = false;
    
    private int ip;                     // current instruction pointer
    private x86CodeBuffer mc;           // code repository
    private Map/*<Object,Integer>*/ branchtargetmap;
    private Relation/*<Object,Set<PatchInfo>>*/ branches_to_patch;
    // Bounds of the currently open dynamic-patch region; see
    // startDynamicPatch/endDynamicPatch.
    private int dynPatchStart, dynPatchSize;
}