(** * Module CFC *)


Require Import Coqlib Maps List Integers Errors.
Import ListNotations.
Require Import AST.
Require Import RTL RTLpast Op Registers.
Require Import CounterMeasures.
Require Compopts.

Open Scope error_monad_scope.

(** Add intra-procedural control-flow checking to RTL functions. *)


(** External oracle: given the code of a function, returns a pair
    [(sigs, joins)] of per-node signatures and the set of join points.
    It is a [Parameter], so its result is not trusted: [compute_sigs]
    validates it with [correct_sigs] and [robust_sigs] before use.
    NOTE(review): presumably realized in OCaml at extraction time;
    confirm against the extraction directives. *)
Parameter ext_compute_sigs : code -> (PTree.t int * PTree.t unit).

(** ** Analyse the control-flow graph to compute node signatures *)
Section compute_sigs.

(** Defensive validator: the signatures come from an external oracle, so they are checked before being used. *)

  (** [is_join joins pc] is [true] exactly when [joins] has an entry for
      node [pc], i.e. when [pc] is a join point. *)
  Definition is_join joins pc :=
    match joins!pc with Some tt => true | _ => false end.

  (** [correct_sigs cfg sigs joins] checks that signatures are consistent
      along every edge that the transformation leaves un-instrumented.
      For each instruction at [pc] with a single successor [pc1]:
      - if [pc1] is a join point, any signatures are accepted, because the
        transformation inserts an explicit xor update of the signature
        register on that edge;
      - otherwise [pc] and [pc1] must both lack a signature, or both have
        the same one, because the signature register is not modified on
        that edge.
      Jump tables are rejected (not supported yet).  [Icond] successors are
      checked separately by [robust_sigs]; [Itailcall] and [Ireturn] have no
      successor.  Returns [OK tt], or the first error produced during the
      fold. *)
  Definition correct_sigs (cfg: code) (sigs: PTree.t int) (joins: PTree.t unit) : Errors.res unit :=
    PTree.fold
      (fun acc pc i =>
         (* Once an error has been produced, this bind skips all later nodes. *)
         do tt <- acc;
         match i with
         | Inop pc1 | Iop _ _ _ pc1
         | Iload _ _ _ _ _ pc1 | Istore _ _ _ _ pc1
         | Icall _ _ _ _ pc1 | Ibuiltin _ _ _ pc1
         | Iassert _ _ pc1 =>
             (* Edges into a join point are re-synchronized by an xor, so no
                constraint; other edges must preserve the signature. *)
             if is_join joins pc1 then OK tt
             else match sigs!pc, sigs!pc1 with
                  | None, None => OK tt
                  | Some sig1, Some sig2 =>
                      if Int.eq_dec sig1 sig2 then OK tt
                      else Error [MSG "signature mismatch: l"; POS pc; MSG " and l"; POS pc1]
                  | _, _ => Error [MSG "signature mismatch: l"; POS pc; MSG " and l"; POS pc1]
                  end
         | Ijumptable _ pcs => Error (msg "TODO jumptables not supported")
         (* Conditions are validated by [robust_sigs]; tail calls and returns
            have no successor node. *)
         | Icond _ _ _ _ _ | Itailcall _ _ _ | Ireturn _ => OK tt
         end) cfg (OK tt).

  (** [robust_sigs cfg sigs] checks that the two successors of every
      conditional branch can be told apart at run time.  For each [Icond]
      with successors [pc1] and [pc2], both must have a signature, and the
      two signatures must differ unless [pc1] and [pc2] are the same node.
      Distinct signatures let the per-branch test inserted by the
      transformation notice that the wrong successor was reached.
      Other instructions are not constrained here.  Returns [OK tt], or the
      first error produced during the fold. *)
  Definition robust_sigs (cfg: code) (sigs: PTree.t int) : Errors.res unit :=
    PTree.fold
      (fun acc pc i =>
         (* Once an error has been produced, this bind skips all later nodes. *)
         do tt <- acc;
         match i with
         | Icond _ _ pc1 pc2 _ =>
             match sigs!pc1, sigs!pc2 with
             | Some sig1, Some sig2 =>
                 (* [peq] and [Int.eq_dec] are decision procedures, used here
                    as booleans via coercion. *)
                 if peq pc1 pc2 || negb (Int.eq_dec sig1 sig2) then OK tt
                 else Error [MSG "same signatures for cond: l"; POS pc1; MSG ", l"; POS pc2]
             | _, _ => Error [MSG "missing signature for cond: l"; POS pc1; MSG ", l"; POS pc2]
             end
         | _ => OK tt
         end) cfg (OK tt).

  (** [compute_sigs f] asks the oracle [ext_compute_sigs] for signatures and
      join points of the code of [f], validates them with [correct_sigs] and
      then [robust_sigs], and returns the pair [(sigs, joins)] unchanged when
      both checks succeed.  Otherwise the first validation error is
      returned. *)
  Definition compute_sigs f :=
    let (sigs, joins) := ext_compute_sigs (fn_code f) in
    do tt <- correct_sigs (fn_code f) sigs joins;
    do tt <- robust_sigs (fn_code f) sigs;
    OK (sigs, joins).

End compute_sigs.

(** ** Transform the graph using the validated signatures *)

Section transfer_instruction.

(** Per-node signatures; [transf_function] passes the map returned by
    [compute_sigs]. *)
Variable sigs : PTree.t int.
(** Set of join points, also from [compute_sigs]; queried with [is_join]. *)
Variable joins : PTree.t unit.
(** [gsr] holds the running signature that the instrumentation checks and
    updates; [rts] is a scratch register holding the xor mask chosen at a
    conditional branch. *)
Variable gsr rts : reg.

(** [get_sig pc] is the signature of node [pc], or [Int.zero] when [sigs]
    has no entry for [pc]. *)
Definition get_sig pc :=
  match sigs!pc with
  | Some sig => sig
  | None => Int.zero
  end.

(** Transfer a single-successor instruction, updating the signature register when the successor is a join point. *)
(** [transf_join pc1 pc2 i] instruments the edge from node [pc1] to its
    successor [pc2].  [i] is the original instruction with its successor
    left abstract.
    - If [pc2] is a join point, [i] is followed by
      [gsr := gsr xor (sig pc1 xor sig pc2)].  This turns a correct value
      [sig pc1] of [gsr] into [sig pc2].
    - Otherwise [i] is emitted alone, leaving [gsr] unchanged.  This is
      why [correct_sigs] requires equal signatures on such edges. *)
Definition transf_join pc1 pc2 (i: node -> instruction) : seqinstr OneExit :=
  if is_join joins pc2 then
    Ssinstr i;; Seinstr (Iop (Oxorimm (Int.xor (get_sig pc1) (get_sig pc2))) [gsr] gsr)
  else Seinstr i.

(** Despite its name, [set_xor_gsr] does not write [gsr].  It branches on
    condition [c] over [rs] and loads into [rts] one of two masks:
    - [sig xor sig1] in the first branch of [Ssmerge2], presumably the
      branch taken when [c] holds;
    - [sig xor sig2] in the second branch.
    Xoring [rts] into [gsr] later, in [add_test], moves [gsr] from [sig] to
    the signature of the branch target.  [pred] is forwarded to [Ssmerge2]
    unchanged.  NOTE(review): presumably branch-prediction information;
    confirm in RTLpast. *)
Definition set_xor_gsr {e} c rs sig sig1 sig2 pred : _ -> seqinstr e :=
  Ssmerge2 c rs
    (Seinstr (Iop (Ointconst (Int.xor sig sig1)) [] rts))
    (Seinstr (Iop (Ointconst (Int.xor sig sig2)) [] rts))
    pred.

(** [add_test sig] applies the mask and checks the result:
    - [gsr := gsr xor rts];
    - branch to [Secatch] when [gsr] differs from [sig];
    - pass [gsr] to an [EF_observe] builtin.
    NOTE(review): [Secatch] is presumably the countermeasure handler and
    [Some false] a hint that the mismatch branch is not taken.  The observe
    builtin presumably prevents later passes from removing the check.
    Confirm all three. *)
Definition add_test sig : seqinstr OneExit :=
  Ssinstr (Iop Oxor [gsr;rts] gsr) ;;
  Ssmerge1 (Ccompimm Cne sig) [gsr] Secatch (Some false);;
  Seinstr (Ibuiltin (EF_observe [Tint]) [BA gsr] BR_none).

(** [transf_cond pc c rs pc1 pc2 pred] instruments the conditional at [pc],
    whose successors are [pc1] and [pc2].  In order, it:
    - asserts [gsr = sig pc];
    - computes the branch-dependent mask into [rts] with [set_xor_gsr];
    - observes the condition arguments;
    - evaluates [c] a second time through [Second], applying [add_test]
      with the signature of the successor selected by that evaluation.
    NOTE(review): the double evaluation presumably lets a fault that
    corrupts one evaluation be detected, since the mask and the expected
    signature would then disagree; confirm the semantics of [Second] in
    RTLpast. *)
Definition transf_cond pc c rs pc1 pc2 pred :=
  let sig := get_sig pc in
  let sig1 := get_sig pc1 in
  let sig2 := get_sig pc2 in
  Ssinstr (Iassert (Ccompimm Ceq sig) [AA_reg gsr]);;
  set_xor_gsr c rs sig sig1 sig2 pred;;
  Ssinstr (Ibuiltin (EF_observe (type_of_condition c)) (List.map (@BA _) rs) BR_none);;
  Second c rs (add_test sig1) (add_test sig2) pred.

(** [transf_return pc r] instruments a return at [pc].
    - When the [intra_cfc_return] option is set, it asserts
      [gsr = sig pc] and branches to [Secatch] if [gsr] differs from
      [sig pc], then returns [r].
    - Otherwise the return is left unchanged. *)
Definition transf_return pc r :=
  let sig := get_sig pc in
  if (Compopts.intra_cfc_return tt) then
    Ssinstr (Iassert (Ccompimm Ceq sig) [AA_reg gsr]);;
    Ssmerge1 (Ccompimm Cne sig) [gsr] Secatch (Some false);;
    Sereturn (Ireturn r)
  else Sereturn (Ireturn r).

(** Single-successor instructions.  Each wrapper partially applies the
    original constructor, leaving its successor abstract, and hands it to
    [transf_join] for the edge from [pc] to [pc1]. *)
Definition transf_nop pc pc1 :=
  transf_join pc pc1 Inop.
Definition transf_op pc op rs r pc1 :=
  transf_join pc pc1 (Iop op rs r).
Definition transf_load pc trp chk addr rs r pc1 :=
  transf_join pc pc1 (Iload trp chk addr rs r).
Definition transf_store pc chk addr rs r pc1 :=
  transf_join pc pc1 (Istore chk addr rs r).
Definition transf_call pc sig ros rs r pc1 :=
  transf_join pc pc1 (Icall sig ros rs r).
Definition transf_builtin pc ef rs r pc1 :=
  transf_join pc pc1 (Ibuiltin ef rs r).

End transfer_instruction.

(** [do_protect f] holds when [f] carries the [Harden CFC] attribute, or
    when the [intra_cfc_all] option is set. *)
Definition do_protect f :=
  in_dec eq_function_attr (Harden CFC) (fn_attr f) || Compopts.intra_cfc_all tt.

(** [transf_function f] adds control-flow checking to [f] when
    [do_protect f] holds.  Otherwise it keeps the code and entry point of
    [f] unchanged.  On the protected path:
    - two fresh registers above [max_reg_function f] are chosen: [gsr] for
      the running signature and [rts] for the branch mask;
    - signatures are computed and validated by [compute_sigs], and any
      error is propagated;
    - the generic [transf_function] rebuilds the code.  This is the one in
      scope before this definition, since [Definition] is not recursive;
      presumably it comes from RTLpast.  The first argument is a prologue
      that sets [gsr] to the signature of the entry point.
    Tail calls and jump tables use the default transfers.  Jump tables
    never reach here on this path, because [correct_sigs] rejects them.
    The resulting (code, entry point) pair is stored in a function that
    keeps the signature, parameters, stack size and attributes of [f]. *)
Definition transf_function (f : function) :=
  do res <-
    if do_protect f then
      let maxr := max_reg_function f in
      (* Fresh registers, distinct from every register already used by [f]. *)
      let gsr := Pos.succ maxr in
      let rts := Pos.succ gsr in
      do (sigs, joins) <- compute_sigs f;
      OK (transf_function
            (Seinstr (Iop (Ointconst (get_sig sigs (fn_entrypoint f))) [] gsr))
            (transf_nop sigs joins gsr) (transf_op sigs joins gsr)
            (transf_load sigs joins gsr) (transf_store sigs joins gsr)
            (transf_call sigs joins gsr) transf_tailcall_default (transf_builtin sigs joins gsr)
            (transf_cond sigs gsr rts)
            transf_jumptable_default (transf_return sigs gsr)
            f)
    else OK (fn_code f, fn_entrypoint f);
  OK (mkfunction (fn_sig f) (fn_params f) (fn_stacksize f) (fst res) (snd res) (fn_attr f)).

(** Apply [transf_function] to an internal function.  External functions
    are presumably left unchanged by [AST.transf_partial_fundef]. *)
Definition transf_fundef (fd: fundef) :=
  AST.transf_partial_fundef transf_function fd.

(** Apply [transf_fundef] to every function definition of [p].  The
    result is an error if the transformation of any function fails. *)
Definition transf_program (p: program) :=
  AST.transform_partial_program transf_fundef p.