Module Canary

Require Import Coqlib Wfsimpl Maps Errors Integers Values.
Require Import AST Linking Globalenvs.
Require Import Op Registers RTL Memdata.
Require Compopts.

(** Memory chunk used both to save the canary in the frame and to reload
    it before leaving the function: [Mint64] when [Archi.canary64] holds,
    [Mint32] otherwise. *)
Definition canary_chunk : memory_chunk :=
  if Archi.canary64
  then Mint64
  else Mint32.

(** Size reserved in the frame for the canary slot, selected on the same
    [Archi.canary64] flag as [canary_chunk] (8 versus 4).
    NOTE(review): presumably equal to [size_chunk canary_chunk]; this is
    not enforced here. *)
Definition canary_size : Z :=
  if Archi.canary64
  then 8
  else 4.

(** [padding ofs al] is the amount to add to [ofs] to reach the next
    multiple of [al]. It is [0] when [ofs] is already a multiple of [al].
    Intended for [al > 0]; callers in this file pass [align_chunk]. *)
Definition padding (ofs al : Z) : Z :=
  let remainder := Z.modulo ofs al in
  match zeq remainder 0 with
  | left _ => 0
  | right _ => al - remainder
  end.

(** Frame offset at which the canary is stored: the original stack size of
    [f], rounded up to the alignment of [canary_chunk]. *)
Definition canary_offset (f : function) : Z :=
  fn_stacksize f + padding (fn_stacksize f) (align_chunk canary_chunk).

(** Decides whether [f] gets a canary. All of the following must hold:
    - the target provides a canary operation ([getcanary] is [Some]);
    - the stack size of [f] is non-negative;
    - [canary_offset f] fits in an unsigned [Ptrofs] value, so that
      [Ptrofs.repr] of it denotes the intended offset;
    - the stack protector option is enabled;
    - either the protect-all option is enabled, or [f] has a non-empty
      stack frame. *)
Definition canary_needed (f : function) : bool :=
  match getcanary with
  | None => false
  | Some _ =>
      zle 0 (fn_stacksize f)
      && zle (canary_offset f) Ptrofs.max_unsigned
      && Compopts.stack_protector tt
      && (Compopts.stack_protector_all tt || zlt 0 (fn_stacksize f))
  end.

(** Number of bytes added to the frame of [f]. When a canary is needed,
    this is the alignment padding plus the canary slot, so the new stack
    size equals [canary_offset f + canary_size]. Otherwise it is [0]. *)
Definition extra_canary_size (f : function) : Z :=
  match canary_needed f with
  | true => padding (fn_stacksize f) (align_chunk canary_chunk) + canary_size
  | false => 0
  end.

(** Emits one instruction at the free node carried in [st]. The builder
    [fop] receives that node, so it can name its own successor. The result
    is the updated code together with the next free node. Any instruction
    already present at the free node is overwritten by [PTree.set]. *)
Definition code_append (st : code * node) (fop : node -> instruction) : code * node :=
  match st with
  | (tree, free_pc) => (PTree.set free_pc (fop free_pc) tree, Pos.succ free_pc)
  end.

(** Prologue emitted from the free node of [code0]. It consists of three
    consecutive nodes:
    - compute the canary into [tmpreg] with the target canary operation;
    - store [tmpreg] into the frame at [offset] using [canary_chunk];
    - execute [clearcanary] into [tmpreg] (presumably to scrub the canary
      value from the register), then jump to [entrypoint].
    When the target has no canary operation, [code0] is returned
    unchanged. *)
Definition entry_sequence entrypoint offset tmpreg (code0 : code * node) : code * node :=
  match getcanary with
  | Some load_canary =>
      let after_load :=
        code_append code0
          (fun n => Iop load_canary nil tmpreg (Pos.succ n)) in
      let after_store :=
        code_append after_load
          (fun n => Istore canary_chunk (Ainstack (Ptrofs.repr offset)) nil tmpreg (Pos.succ n)) in
      code_append after_store
        (fun _ => Iop clearcanary nil tmpreg entrypoint)
  | None => code0
  end.

(** Epilogue check emitted from the free node of [code0], ending with
    [insn]. It uses four consecutive nodes:
    - reload the saved canary from the frame at [offset] into [tmpreg];
    - recompute the canary into [Pos.succ tmpreg];
    - compare the two registers; when [negate_condition canary_cmp] holds,
      branch to [crash_pc], otherwise fall through (the [Some false]
      annotation is presumably a hint that the crash branch is unlikely);
    - finally execute [insn] itself.
    When the target has no canary operation, [code0] is returned
    unchanged. *)
Definition check_sequence crash_pc offset tmpreg insn (code0 : code * node) : code * node :=
  match getcanary with
  | Some load_canary =>
      let fresh_reg := Pos.succ tmpreg in
      let after_reload :=
        code_append code0
          (fun n => Iload TRAP canary_chunk (Ainstack (Ptrofs.repr offset)) nil tmpreg (Pos.succ n)) in
      let after_recompute :=
        code_append after_reload
          (fun n => Iop load_canary nil fresh_reg (Pos.succ n)) in
      let after_compare :=
        code_append after_recompute
          (fun n => Icond (negate_condition canary_cmp) (tmpreg :: fresh_reg :: nil)
                          crash_pc (Pos.succ n) (Some false)) in
      code_append after_compare (fun _ => insn)
  | None => code0
  end.

(** Instructions that leave the current frame and must therefore be
    preceded by a canary check: returns and tail calls. *)
Definition check_insn (insn : instruction) : bool :=
  match insn with
  | Ireturn _ => true
  | Itailcall _ _ _ => true
  | _ => false
  end.

(** Fold step over the original code. When [insn], found at [at_pc], is a
    frame-leaving instruction, this step does two things:
    - it appends a [check_sequence] ending in [insn] at the free node of
      [code0];
    - it replaces the instruction at [at_pc] by an [Inop] jumping to the
      start of that sequence.
    Other instructions leave [code0] unchanged.
    NOTE(review): if [getcanary] were [None], [check_sequence] would emit
    nothing, and [at_pc] would jump to an unpopulated node. Callers in this
    file only reach this step when [canary_needed] holds, which requires
    [getcanary] to be [Some]. *)
Definition apply_check (crash_pc : node) offset tmpreg (code0 : code*node) (at_pc : node) (insn : instruction) : code*node :=
  match check_insn insn with
  | false => code0
  | true =>
      let extended := check_sequence crash_pc offset tmpreg insn code0 in
      (PTree.set at_pc (Inop (snd code0)) (fst extended), snd extended)
  end.
  
(** Runs [apply_check] over every node of [original], threading the code
    being built (initially [code0]) through the fold. *)
Definition apply_checks (crash_pc : node) (offset : Z) (tmpreg : reg)
                        (original : code) (code0 : code * node) : code * node :=
  PTree.fold (apply_check crash_pc offset tmpreg) original code0.

(** Builtin invoked when a canary mismatch is detected. It takes no
    arguments and returns nothing. *)
Definition crash_call : external_function :=
  EF_builtin "stack_chk_fail" (mksignature nil Tvoid cc_default).

(** Crash handler emitted from the free node of [code0], using two nodes:
    - a call to [crash_call];
    - an [Inop] jumping back to that call.
    If the builtin ever returns, control therefore loops on the call
    forever. *)
Definition crash_sequence (code0 : code * node) : code * node :=
  let with_call :=
    code_append code0
      (fun n => Ibuiltin crash_call nil BR_none (Pos.succ n)) in
  code_append with_call (fun _ => Inop (snd code0)).
                           
(** Instruments [f] with a stack canary when [canary_needed f] holds.
    New nodes are allocated from [Pos.succ (max_pc_function f)] and are
    laid out as follows:
    - the canary prologue, which becomes the new entry point;
    - the crash handler;
    - one check sequence per return or tail call, with the original
      instruction replaced by a jump into its check.
    The temporaries [tmpreg] and [Pos.succ tmpreg] lie above every
    register of [f]. The stack size grows by [extra_canary_size f].
    Otherwise [f] is returned with the same code and entry point. *)
Definition transf_function (f : function) : function :=
  let needed := canary_needed f in
  let tmpreg := Pos.succ (max_reg_function f) in
  let next_pc := Pos.succ (max_pc_function f) in
  let offset := canary_offset f in
  {| fn_sig := fn_sig f;
     fn_params := fn_params f;
     fn_code :=
       match needed with
       | true =>
           let start := (fn_code f, next_pc) in
           let with_entry := entry_sequence (fn_entrypoint f) offset tmpreg start in
           let with_crash := crash_sequence with_entry in
           (* [snd with_entry] is the first node of the crash handler. *)
           let with_checks := apply_checks (snd with_entry) offset tmpreg (fn_code f) with_crash in
           fst with_checks
       | false => fn_code f
       end;
     fn_stacksize := fn_stacksize f + extra_canary_size f;
     fn_entrypoint :=
       match needed with
       | true => next_pc
       | false => fn_entrypoint f
       end |}.

(** Lifts [transf_function] to function definitions; external functions
    are left as they are by [AST.transf_fundef].
    The generic combinator is named explicitly. This definition shadows
    [AST.transf_fundef], and an unqualified reference in the body only
    resolves to the generic one because a [Definition] is not recursive. *)
Definition transf_fundef (fd: fundef) : fundef :=
  AST.transf_fundef transf_function fd.

(** Whole-program entry point: applies [transf_fundef] (the canary version
    defined above) to every function of [p] via [AST.transform_program]. *)
Definition transf_program (p : program) : program :=
  AST.transform_program transf_fundef p.