(* Module RTLTunneling *)


(** * Branch tunneling for the RTL representation *)

Require Import Coqlib Maps Errors.
Require Import AST.
Require Import RTL.


(** Table produced by the tunneling oracle: a finite map from CFG nodes to
    pairs made of a node and an integer.  The rest of this file reads the
    first component as the node's target and the absolute value of the
    second as a distance (see [get], [target] and [check_instr]). *)
Definition UF: Type := PTree.t (prod node Z).

(** External oracle mapping a function to a [UF] table.  It is declared as
    an axiom here and bound, at extraction time, to the implementation named
    in the [Extract Constant] directive below.  Its result is not trusted:
    [tunnel_function] runs [check_included] and [check_code] on the returned
    table before using it. *)
Axiom branch_target: RTL.function -> UF.
Extract Constant branch_target => "RTLTunnelingaux.branch_target".

Local Open Scope error_monad_scope.

(** [get td pc] looks [pc] up in [td].
    - If [pc] is unbound, the result is [(pc, 0)]: the node is its own
      target, at distance zero.
    - If [pc] is bound to [(t, d)], the result is [(t, Z.abs d)]; the sign
      of the stored integer is discarded here. *)
Definition get (td: UF) (pc: node): node*Z :=
  match PTree.get pc td with
  | None => (pc, 0)
  | Some entry =>
      let (t, d) := entry in
      (t, Z.abs d)
  end.


(** [target td pc] is the node component of [get td pc]; for a [pc] that is
    unbound in [td] this is [pc] itself. *)
Definition target (td: UF) (pc: node): node :=
  @fst node Z (get td pc).

(* Lets a table [td : UF] be applied directly as a function on nodes:
   [td pc] elaborates to [target td pc]. *)
Coercion target: UF >-> Funclass.


(** Checks that every key bound in [td] is also bound in the code map [c].

    The fold threads an [option instruction] used as a flag: it starts as
    [Some (Inop xH)] (a placeholder whose payload is never inspected), is
    replaced by the lookup [c!pc] at each key [pc] of [td] as long as it is
    still [Some _], and stays [None] once a lookup has failed.  Callers only
    test whether the result is [Some _] or [None]. *)
Definition check_included (td: UF) (c: code): option instruction :=
  PTree.fold
    (fun (acc: option instruction) pc _ =>
       match acc with
       | Some _ => PTree.get pc c
       | None => None
       end)
    td
    (Some (Inop xH)).

(** Validates the [td] entry of node [pc], whose instruction is [i].

    - If [pc] is unbound in [td] there is nothing to check: [OK tt].
    - Otherwise let [root] be the node stored for [pc] and [bound] the
      absolute value of the stored integer.  The instruction must be an
      [Inop] or an [Icond], and for each of its successors [s] the pair
      [get td s] must have [root] as first component and a second component
      strictly smaller than [bound].
    - Any other instruction at a bound node is rejected.

    The checks are performed in a fixed order (targets of both branches
    first, then distances, then-branch before else-branch), which determines
    the error message reported when several checks would fail. *)
Definition check_instr (td: UF) (pc: node) (i: instruction): res unit :=
  match PTree.get pc td with
  | Some (root, raw) =>
      let bound := Z.abs raw in
      match i with
      | Icond _ _ ifso ifnot _ =>
          let '(tso, dso) := get td ifso in
          let '(tnot, dnot) := get td ifnot in
          match peq root tso with
          | right _ => Error (msg "invalid skip of then branch")
          | left _ =>
              match peq root tnot with
              | right _ => Error (msg "invalid skip of else branch")
              | left _ =>
                  match zlt dso bound with
                  | right _ => Error (msg "bad distance on then branch")
                  | left _ =>
                      match zlt dnot bound with
                      | right _ => Error (msg "bad distance on else branch")
                      | left _ => OK tt
                      end
                  end
              end
          end
      | Inop s =>
          let '(ts, ds) := get td s in
          match peq root ts with
          | right _ => Error (msg "invalid skip of Inop")
          | left _ =>
              match zlt ds bound with
              | right _ => Error (msg "bad distance in Inop")
              | left _ => OK tt
              end
          end
      | _ => Error (msg "cannot skip this instruction")
      end
  | None => OK tt
  end.

(** Runs [check_instr td] on every binding [(pc, i)] of the code map [c].
    The accumulator is a [res unit] starting at [OK tt]; each step is bound
    after the accumulated result, so once an [Error] has been produced it is
    propagated and no later check replaces it. *)
Definition check_code (td: UF) (c: code): res unit :=
  PTree.fold
    (fun (acc: res unit) pc i =>
       Errors.bind acc (fun _ => check_instr td pc i))
    c
    (OK tt).


(** Rewrites the successor nodes of instruction [i] through the node
    renaming [t].

    - Instructions with a single successor get that successor replaced by
      its image under [t]; a jump table has [t] mapped over all its entries.
    - For a conditional, both successors are renamed; if the two renamed
      successors are the same node, the conditional is replaced by an [Inop]
      to that node, otherwise it is rebuilt with the renamed successors.
    - Any other instruction is returned unchanged. *)
Definition tunnel_instr (t: node -> node) (i: instruction) : instruction :=
  match i with
  | Icond cond args ifso ifnot info =>
      let tso := t ifso in
      let tnot := t ifnot in
      match peq tso tnot with
      | left _ => Inop tso
      | right _ => Icond cond args tso tnot info
      end
  | Ijumptable arg tbl => Ijumptable arg (List.map t tbl)
  | Inop s => Inop (t s)
  | Iop op args dst s => Iop op args dst (t s)
  | Iload trap chunk addr args dst s => Iload trap chunk addr args dst (t s)
  | Istore chunk addr args src s => Istore chunk addr args src (t s)
  | Icall sg ros args dst s => Icall sg ros args dst (t s)
  | Ibuiltin ef args dst s => Ibuiltin ef args dst (t s)
  | _ => i
  end.

(** Tunnels one function.

    The table [uf] is obtained from the [branch_target] oracle and is
    validated before use:
    - [check_included] must succeed, otherwise the function is rejected
      with the union-find/CFG error below;
    - [check_code] must return [OK], otherwise its error is propagated.

    On success the result keeps the signature, parameters and stack size of
    [f], applies [tunnel_instr (target uf)] to every instruction of the
    code, and replaces the entry point by its target in [uf]. *)
Definition tunnel_function (f: RTL.function): res RTL.function :=
  let uf := branch_target f in
  let body := f.(fn_code) in
  match check_included uf body with
  | Some _ =>
      do _ <- check_code uf body;
      OK (mkfunction
            f.(fn_sig)
            f.(fn_params)
            f.(fn_stacksize)
            (PTree.map1 (tunnel_instr (target uf)) body)
            (target uf f.(fn_entrypoint)))
  | None => Error (msg "Some node of the union-find is not in the CFG")
  end.

(** Lifts [tunnel_function] to function definitions by passing it to
    [transf_partial_fundef]. *)
Definition tunnel_fundef (f: RTL.fundef): res RTL.fundef :=
  AST.transf_partial_fundef tunnel_function f.

(** Entry point of the pass: applies [tunnel_fundef] to a whole program
    through [transform_partial_program]. *)
Definition transf_program (p: RTL.program): res RTL.program :=
  AST.transform_partial_program tunnel_fundef p.