(** * Module CSE2 *)

Require Import Coqlib Maps Errors Integers Floats Lattice Kildall.
Require Import AST Linking.
Require Import Memory Registers Op RTL Maps CSE2deps.

(** Symbolic value that the analysis associates with a register.
    - [SMove src]: the register was set by a move from [src].
    - [SOp op args]: the register was set by operation [op] on [args].
    - [SLoad chunk addr args]: the register was set by a load of [chunk]
      at addressing mode [addr] applied to [args].
    The [SLoad] constructor is referenced by [kill_sym_val_mem] below;
    its argument types follow how [chunk], [addr] and [args] are used by
    the store-invalidation functions. *)
Inductive sym_val : Type :=
| SMove (src : reg)
| SOp (op : operation) (args : list reg)
| SLoad (chunk : memory_chunk) (addr : addressing) (args : list reg).

(** Decidable equality on argument lists: [peq] on registers lifted to
    lists through [list_eq_dec]. *)
Definition eq_args : forall x y : list reg, {x = y} + {x <> y} :=
  list_eq_dec peq.

(** Decidable equality on symbolic values.  The decision procedures for
    every component type are pushed into the goal so that
    [decide equality] can discharge the sub-goals it generates.
    [eq_addressing] is needed for the addressing-mode component of a
    load; it is harmless if no constructor carries an [addressing]. *)
Definition eq_sym_val : forall x y : sym_val,
    {x = y} + {x <> y}.
Proof.
  generalize eq_operation.
  generalize eq_args.
  generalize peq.
  generalize eq_addressing.
  generalize chunk_eq.
  decide equality.
Defined.

(** Semilattice (without bottom) of analysis facts: finite maps from
    registers to the symbolic value currently known for them. *)
Module RELATION <: SEMILATTICE_WITHOUT_BOTTOM.

(** A relation is a [PTree] mapping registers to symbolic values. *)
Definition t := (PTree.t sym_val).
(** Two relations are equal when they agree on every key (extensional
    equality, not structural equality of the trees). *)
Definition eq (r1 r2 : t) :=
forall x, (PTree.get x r1) = (PTree.get x r2).

(** The top element is the empty map: nothing is known. *)
Definition top : t := PTree.empty sym_val.

Lemma eq_refl: forall x, eq x x.
Proof.
unfold eq.
intros; reflexivity.
Qed.

Lemma eq_sym: forall x y, eq x y -> eq y x.
Proof.
unfold eq.
intros; eauto.
Qed.

Lemma eq_trans: forall x y z, eq x y -> eq y z -> eq x z.
Proof.
unfold eq.
intros; congruence.
Qed.

(** Boolean version of [eq_sym_val]. *)
Definition sym_val_beq (x y : sym_val) :=
if eq_sym_val x y then true else false.

(** Boolean equality test on relations, comparing bindings with
    [sym_val_beq]. *)
Definition beq (r1 r2 : t) := PTree.beq sym_val_beq r1 r2.

(** Soundness of [beq]: a positive answer implies [eq]. *)
Lemma beq_correct: forall r1 r2, beq r1 r2 = true -> eq r1 r2.
Proof.
unfold beq, eq. intros r1 r2 EQ x.
pose proof (PTree.beq_correct sym_val_beq r1 r2) as CORRECT.
destruct CORRECT as [CORRECTF CORRECTB].
pose proof (CORRECTF EQ x) as EQx.
clear CORRECTF CORRECTB EQ.
unfold sym_val_beq in *.
destruct (r1 ! x) as [R1x | ] in *;
destruct (r2 ! x) as [R2x | ] in *;
destruct (eq_sym_val R1x R2x) in *; congruence.
Qed.

(** [ge r1 r2]: every binding present in [r1] is present, with the same
    value, in [r2].  Keys unbound in [r1] impose no constraint. *)
Definition ge (r1 r2 : t) :=
forall x,
match PTree.get x r1 with
| None => True
| Some v => (PTree.get x r2) = Some v
end.

Lemma ge_refl: forall r1 r2, eq r1 r2 -> ge r1 r2.
Proof.
unfold eq, ge.
intros r1 r2 EQ x.
pose proof (EQ x) as EQx.
clear EQ.
destruct (r1 ! x).
- congruence.
- trivial.
Qed.

Lemma ge_trans: forall x y z, ge x y -> ge y z -> ge x z.
Proof.
unfold ge.
intros r1 r2 r3 GE12 GE23 x.
pose proof (GE12 x) as GE12x; clear GE12.
pose proof (GE23 x) as GE23x; clear GE23.
destruct (r1 ! x); trivial.
destruct (r2 ! x); congruence.
Qed.

(** Least upper bound: keep a binding only when both relations bind the
    key to the same symbolic value; drop it otherwise. *)
Definition lub (r1 r2 : t) :=
PTree.combine
(fun ov1 ov2 =>
match ov1, ov2 with
| (Some v1), (Some v2) =>
if eq_sym_val v1 v2
then ov1
else None
| None, _
| _, None => None
end)
r1 r2.

Lemma ge_lub_left: forall x y, ge (lub x y) x.
Proof.
unfold ge, lub.
intros r1 r2 x.
rewrite PTree.gcombine by reflexivity.
destruct (_ ! _); trivial.
destruct (_ ! _); trivial.
destruct (eq_sym_val _ _); trivial.
Qed.

Lemma ge_lub_right: forall x y, ge (lub x y) y.
Proof.
unfold ge, lub.
intros r1 r2 x.
rewrite PTree.gcombine by reflexivity.
destruct (_ ! _); trivial.
destruct (_ ! _); trivial.
destruct (eq_sym_val _ _); trivial.
congruence.
Qed.

End RELATION.

(** [RB] adds a bottom element to [RELATION]; the rest of this file uses
    [RB.t] as an option type ([None] / [Some rel]) and [RB.bot]. *)
Module RB := ADD_BOTTOM(RELATION).

(** Forward dataflow solver instantiated on [RB]. *)
Module DS := Dataflow_Solver(RB)(NodeSetForward).

(** [kill_sym_val dst sv] is [true] when the symbolic value [sv] reads
    register [dst], i.e. when it must be discarded once [dst] is
    overwritten.  Loads are treated like operations: they are killed
    when [dst] occurs among their address arguments. *)
Definition kill_sym_val (dst : reg) (sv : sym_val) :=
  match sv with
  | SMove src => if peq dst src then true else false
  | SOp op args => List.existsb (peq dst) args
  | SLoad chunk addr args => List.existsb (peq dst) args
  end.

(** Forget everything about [dst]: remove its own binding, then drop
    every remaining binding whose symbolic value mentions [dst]. *)
Definition kill_reg (dst : reg) (rel : RELATION.t) :=
  PTree.filter1
    (fun sv => negb (kill_sym_val dst sv))
    (PTree.remove dst rel).

(** [kill_sym_val_mem sv] tells whether [sv] must be discarded when
    memory may have changed arbitrarily: loads always, operations only
    when [op_depends_on_memory] says so, moves never. *)
Definition kill_sym_val_mem (sv: sym_val) :=
  match sv with
  | SLoad _ _ _ => true
  | SOp o _ => op_depends_on_memory o
  | SMove _ => false
  end.

(** [kill_sym_val_store chunk addr args sv] tells whether [sv] must be
    discarded after a store of [chunk] at address [addr] / [args].
    Moves are unaffected; operations are killed when they depend on
    memory; a recorded load is killed when [may_overlap] cannot rule out
    that it reads the stored location. *)
Definition kill_sym_val_store chunk addr args (sv: sym_val) :=
  match sv with
  | SMove _ => false
  | SOp op _ => op_depends_on_memory op
  | SLoad chunk' addr' args' => may_overlap chunk addr args chunk' addr' args'
  end.

(** Drop every binding whose symbolic value is invalidated by an
    arbitrary memory change (see [kill_sym_val_mem]). *)
Definition kill_mem (rel : RELATION.t) :=
  PTree.filter1
    (fun sv => negb (kill_sym_val_mem sv))
    rel.

(** [forward_move rel x] returns the source register recorded for [x]
    when [rel] binds [x] to a move, and [x] itself in every other case
    (unbound, or bound to something that is not a move). *)
Definition forward_move (rel : RELATION.t) (x : reg) : reg :=
  match rel ! x with
  | Some (SMove origin) => origin
  | _ => x
  end.

(** Drop every binding invalidated by a store of [chunk] at
    [addr] / [args] (see [kill_sym_val_store]). *)
Definition kill_store1 chunk addr args rel :=
  PTree.filter1
    (fun sv => negb (kill_sym_val_store chunk addr args sv))
    rel.

(** Same as [kill_store1], after rewriting the address arguments with
    [forward_move rel] so they are expressed with the same registers as
    the symbolic values stored in [rel]. *)
Definition kill_store chunk addr args rel :=
  kill_store1 chunk addr
              (List.map (forward_move rel) args)
              rel.

(** Transfer function for [dst := src]: kill all facts about [dst], then
    record that [dst] is a move from [src] (itself resolved through
    [forward_move] in the incoming relation). *)
Definition move (src dst : reg) (rel : RELATION.t) :=
  PTree.set dst
            (SMove (forward_move rel src))
            (kill_reg dst rel).

(** Folding step used by [find_op].  [already] is the result found so
    far: once it is [Some _] it is returned unchanged.  Otherwise the
    binding [x |-> sv] is a hit when [sv] is exactly [SOp op args]. *)
Definition find_op_fold op args (already : option reg) x sv :=
  match already with
  | Some found => already
  | None =>
    match sv with
    | (SOp op' args') =>
      if (eq_operation op op') && (eq_args args args')
      then Some x
      else None
    | _ => None
    end
  end.

(** [find_op rel op args] searches [rel] for a register bound to
    [SOp op args]; it returns [None] when there is none.  The search is
    a [PTree.fold] with [find_op_fold], starting from [None]. *)
Definition find_op (rel : RELATION.t) (op : operation) (args : list reg) :=
  PTree.fold
    (find_op_fold op args)
    rel
    None.

(** Folding step used by [find_load]; the load counterpart of
    [find_op_fold].  Once a result has been found it is kept; otherwise
    the binding [x |-> sv] is a hit when [sv] is a load with the same
    chunk, addressing mode and arguments. *)
Definition find_load_fold chunk addr args (already : option reg) x sv :=
  match already with
  | Some found => already
  | None =>
    match sv with
    | (SLoad chunk' addr' args') =>
      if (chunk_eq chunk chunk') &&
         (eq_addressing addr addr') &&
         (eq_args args args')
      then Some x
      else None
    | _ => None
    end
  end.

(** [find_load rel chunk addr args] searches [rel] for a register bound
    to [SLoad chunk addr args]; it returns [None] when there is none. *)
Definition find_load (rel : RELATION.t) (chunk : memory_chunk)
           (addr : addressing) (args : list reg) :=
  PTree.fold (find_load_fold chunk addr args) rel None.

(** Record [dst := op args] after killing all facts about [dst].  The
    arguments are resolved with [forward_move] in the relation obtained
    after the kill, not in the incoming one. *)
Definition oper2 (op: operation) (dst : reg) (args : list reg)
           (rel : RELATION.t) :=
  let killed := kill_reg dst rel in
  PTree.set dst
            (SOp op (List.map (forward_move killed) args))
            killed.

(** Like [oper2], except that nothing is recorded when [dst] is one of
    its own arguments: in that case the facts about [dst] are only
    killed. *)
Definition oper1 (op: operation) (dst : reg) (args : list reg)
           (rel : RELATION.t) :=
  if List.in_dec peq dst args
  then kill_reg dst rel
  else oper2 op dst args rel.

(** Transfer function for [dst := op args].  If some register is already
    bound to the same operation on the same (forwarded) arguments, the
    instruction is treated as a move from that register; otherwise it
    is handled by [oper1]. *)
Definition oper (op: operation) (dst : reg) (args : list reg)
           (rel : RELATION.t) :=
  match find_op rel op (List.map (forward_move rel) args) with
  | Some available => move available dst rel
  | None => oper1 op dst args rel
  end.

(** Entry point for [Iop]: an [Omove] with exactly one argument is
    handled by [move]; every other operation / argument shape goes
    through [oper]. *)
Definition gen_oper (op: operation) (dst : reg) (args : list reg)
           (rel : RELATION.t) :=
  match op, args with
  | Omove, source :: nil => move source dst rel
  | _, _ => oper op dst args rel
  end.

(** Record [dst := load chunk addr args] after killing all facts about
    [dst]; the load counterpart of [oper2]. *)
Definition load2 (chunk: memory_chunk) (addr : addressing)
           (dst : reg) (args : list reg) (rel : RELATION.t) :=
  let rel' := kill_reg dst rel in
  PTree.set dst (SLoad chunk addr (List.map (forward_move rel') args)) rel'.

(** Like [load2], except that nothing is recorded when [dst] is one of
    the address arguments; the load counterpart of [oper1]. *)
Definition load1 (chunk: memory_chunk) (addr : addressing)
           (dst : reg) (args : list reg) (rel : RELATION.t) :=
  if List.in_dec peq dst args
  then kill_reg dst rel
  else load2 chunk addr dst args rel.

(** Transfer function for a load into [dst].  If some register is
    already bound to the same load on the same (forwarded) arguments,
    the instruction is treated as a move from that register; otherwise
    it is handled by [load1]. *)
Definition load (chunk: memory_chunk) (addr : addressing)
           (dst : reg) (args : list reg) (rel : RELATION.t) :=
  match find_load rel chunk addr (List.map (forward_move rel) args) with
  | Some r => move r dst rel
  | None => load1 chunk addr dst args rel
  end.

(** Kill the facts about the result of a builtin.  Only a result of the
    form [BR r] kills anything; any other result shape leaves the
    relation unchanged. *)
Definition kill_builtin_res res rel :=
  match res with
  | BR dst => kill_reg dst rel
  | _ => rel
  end.

(** Effect of an external function on the relation, as far as memory is
    concerned.
    - [EF_builtin] / [EF_runtime]: unchanged when
      [Builtins.lookup_builtin_function] recognises the function,
      otherwise memory-dependent facts are killed.
    - [EF_malloc], [EF_external], [EF_vstore], [EF_free], [EF_memcpy],
      [EF_inline_asm]: memory-dependent facts are killed.
    - every other external function: unchanged. *)
Definition apply_external_call ef (rel : RELATION.t) : RELATION.t :=
  match ef with
  | EF_builtin name sg
  | EF_runtime name sg =>
    match Builtins.lookup_builtin_function name sg with
    | None => kill_mem rel
    | Some _ => rel
    end
  | EF_malloc
  | EF_external _ _
  | EF_vstore _
  | EF_free
  | EF_memcpy _ _
  | EF_inline_asm _ _ _ => kill_mem rel
  | _ => rel
  end.

(** Transfer function of one instruction on a relation.
    - [Inop], [Icond], [Ijumptable]: relation unchanged.
    - [Iop]: see [gen_oper].
    - [Iload]: see [load].
    - [Istore]: see [kill_store].
    - [Icall]: memory-dependent facts and facts about [dst] are killed.
    - [Ibuiltin]: see [apply_external_call] and [kill_builtin_res].
    - [Itailcall], [Ireturn]: [RB.bot]. *)
Definition apply_instr instr (rel : RELATION.t) : RB.t :=
  match instr with
  | Inop _
  | Icond _ _ _ _ _
  | Ijumptable _ _ => Some rel
  | Istore chunk addr args _ _ => Some (kill_store chunk addr args rel)
  | Iop op args dst _ => Some (gen_oper op dst args rel)
  | Iload trap chunk addr args dst _ => Some (load chunk addr dst args rel)
  | Icall _ _ _ dst _ => Some (kill_reg dst (kill_mem rel))
  | Ibuiltin ef _ res _ => Some (kill_builtin_res res (apply_external_call ef rel))
  | Itailcall _ _ _ | Ireturn _ => RB.bot
  end.

(** Transfer function at program point [pc], lifted to [RB.t]: [None]
    is propagated as [None]; a point with no instruction in [code]
    yields [RB.bot]; otherwise [apply_instr] is applied. *)
Definition apply_instr' code (pc : node) (ro : RB.t) : RB.t :=
  match ro with
  | None => None
  | Some rel =>
    match code ! pc with
    | None => RB.bot
    | Some i => apply_instr i rel
    end
  end.

(** Run the dataflow solver on function [f], starting from
    [Some RELATION.top] at the entry point. *)
Definition forward_map (f : RTL.function) :=
  DS.fixpoint
    (RTL.fn_code f)
    RTL.successors_instr
    (apply_instr' (RTL.fn_code f))
    (RTL.fn_entrypoint f)
    (Some RELATION.top).

(** [forward_move] lifted to [RB.t]: with no relation ([None]) the
    register is returned unchanged. *)
Definition forward_move_b (rb : RB.t) (x : reg) :=
  match rb with
  | Some rel => forward_move rel x
  | None => x
  end.

(** Rewrite register [x] at program point [pc] using the analysis
    result [fmap].  When there is no analysis result, [x] is returned
    unchanged. *)
Definition subst_arg (fmap : option (PMap.t RB.t)) (pc : node) (x : reg) : reg :=
  match fmap with
  | Some inv => forward_move_b (PMap.get pc inv) x
  | None => x
  end.

(** [subst_arg] mapped over a list of registers. *)
Definition subst_args fmap pc := List.map (subst_arg fmap pc).

(** Look up, in the relation attached to [pc] by the analysis result
    [fmap], a register already bound to [op args].  Returns [None] when
    there is no analysis result or no relation at [pc]. *)
Definition find_op_in_fmap fmap pc op args :=
  match fmap with
  | None => None
  | Some inv =>
    match PMap.get pc inv with
    | None => None
    | Some rel => find_op rel op args
    end
  end.

(** Look up, in the relation attached to [pc] by the analysis result
    [fmap], a register already bound to the load [chunk addr args].
    Returns [None] when there is no analysis result or no relation at
    [pc]. *)
Definition find_load_in_fmap fmap pc chunk addr args :=
  match fmap with
  | None => None
  | Some map =>
    match PMap.get pc map with
    | Some rel => find_load rel chunk addr args
    | None => None
    end
  end.

(** Rewrite one instruction using the analysis result [fmap] at [pc].
    - Register arguments are replaced through [subst_arg] /
      [subst_args].
    - A non-trivial [Iop] whose value is already available in some
      register becomes a move from that register.
    - An [Iload] whose value is already available in some register
      becomes a move from that register.
    - Destination registers are never rewritten.
    - [Ireturn None] and every instruction not listed are returned
      unchanged. *)
Definition transf_instr (fmap : option (PMap.t RB.t))
           (pc: node) (instr: instruction) :=
  match instr with
  | Iop op args dst s =>
    let args' := subst_args fmap pc args in
    (* Trivial operations are never replaced by a move. *)
    match (if is_trivial_op op then None else find_op_in_fmap fmap pc op args') with
    | None => Iop op args' dst s
    | Some src => Iop Omove (src::nil) dst s
    end
  | Iload trap chunk addr args dst s =>
    let args' := subst_args fmap pc args in
    match find_load_in_fmap fmap pc chunk addr args' with
    | None => Iload trap chunk addr args' dst s
    | Some src => Iop Omove (src::nil) dst s
    end
  | Istore chunk addr args src s =>
    Istore chunk addr (subst_args fmap pc args) (subst_arg fmap pc src) s
  | Icall sig ros args dst s =>
    Icall sig ros (subst_args fmap pc args) dst s
  | Itailcall sig ros args =>
    Itailcall sig ros (subst_args fmap pc args)
  | Icond cond args s1 s2 i =>
    Icond cond (subst_args fmap pc args) s1 s2 i
  | Ijumptable arg tbl =>
    Ijumptable (subst_arg fmap pc arg) tbl
  | Ireturn (Some arg) =>
    Ireturn (Some (subst_arg fmap pc arg))
  | _ => instr
  end.

(** Transform a function: every instruction of the code is rewritten
    with [transf_instr] under the analysis result [forward_map f]; the
    signature, parameters, stack size and entry point are unchanged. *)
Definition transf_function (f: function) : function :=
  {| fn_sig := fn_sig f;
     fn_params := fn_params f;
     fn_stacksize := fn_stacksize f;
     fn_code := PTree.map (transf_instr (forward_map f)) (fn_code f);
     fn_entrypoint := fn_entrypoint f |}.

(** Lift [transf_function] to function definitions through
    [AST.transf_fundef]. *)
Definition transf_fundef (fd: fundef) : fundef :=
AST.transf_fundef transf_function fd.

(** Transform a whole program by applying [transf_fundef] through
    [transform_program]. *)
Definition transf_program (p: program) : program :=
transform_program transf_fundef p.

(** Relation between a source program and its transformed version:
    function definitions are related by [transf_fundef], variable
    information by equality. *)
Definition match_prog (p tp: RTL.program) :=
  match_program
    (fun _ fd tfd => tfd = transf_fundef fd)
    eq
    p tp.

(** The output of [transf_program] is related to its input by
    [match_prog]. *)
Lemma transf_program_match:
  forall p, match_prog p (transf_program p).
Proof.
  intros p; eapply match_transform_program; eauto.
Qed.