(** Module IntPromotionCommon *)


Require Import Coqlib Values Integers Op.
Require Import ZIntervalDomain.
Require Import OptionMonad.

(** [promote sgn n] converts the [int] [n] to an [int64].
    When [sgn] is [true] the signed interpretation [Int.signed n] is used
    (sign extension); otherwise the unsigned interpretation
    [Int.unsigned n] is used (zero extension). *)
Definition promote (sgn : bool) (n : int) : int64 :=
  Int64.repr (if sgn then Int.signed n else Int.unsigned n).

(** Lifts [promote] to values: a [Vint i] becomes [Vlong (promote sgn i)].
    Every other value (including values that are already [Vlong]) is
    mapped to [Vundef]. *)
Definition promote_val (sgn : bool) (v : val) : val :=
  match v with
  | Vint i => Vlong (promote sgn i)
  | _ => Vundef
  end.

(** Conditional promotion of a flagged value [(p, v)]: when the flag [p]
    is [true] the value is promoted with [promote_val sgn]; when it is
    [false] the value [v] is returned unchanged. *)
Definition promote_valb (sgn : bool) (a : bool * val) : val :=
  let (p, v) := a in
  if p then promote_val sgn v else v.

(** [x] lies in the closed interval [0, Int.max_unsigned]. *)
Definition in_unsigned_range (x : Z) : Prop :=
  0 <= x <= Int.max_unsigned.

(** [x] lies in the closed interval [Int.min_signed, Int.max_signed]. *)
Definition in_signed_range (x : Z) : Prop :=
  Int.min_signed <= x <= Int.max_signed.

(** [x] lies in the closed interval [0, Int.max_signed], i.e. in the
    intersection of [in_unsigned_range] and [in_signed_range] restricted
    to non-negative values. *)
Definition in_int_pos_range (x : Z) : Prop :=
  0 <= x <= Int.max_signed.

(** Description of how an operation can be promoted.
    [op_prom_sgb] selects between the two optional operations according
    to the signedness flag; [sound_op_promotion] passes [op_prom_args] and
    [op_prom_res] as the per-argument and result promotion flags. *)
Record op_promotion : Set := mk_op_prom {
  (* Replacement operation for unsigned promotion ([sgn = false]), if any. *)
  op_prom_usg : option operation;
  (* Replacement operation for signed promotion ([sgn = true]), if any. *)
  op_prom_sgn : option operation;
  (* One flag per argument: [true] means the argument is promoted. *)
  op_prom_args : list bool;
  (* [true] means the result of the operation is promoted. *)
  op_prom_res : bool
}.

(** Selects the promoted operation matching the signedness [sgn]:
    [op_prom_sgn] when [sgn] is [true], [op_prom_usg] otherwise. *)
Definition op_prom_sgb (sgn : bool) (prom : op_promotion) : option operation :=
  if sgn then prom.(op_prom_sgn) else prom.(op_prom_usg).

(** The empty operation promotion: no replacement operation for either
    signedness, no argument flags, and an unpromoted result. *)
Definition op_prom_None: op_promotion :=
  mk_op_prom None None nil false.

(** Description of how a condition can be promoted.
    [cond_prom_sgb] selects between the two optional conditions according
    to the signedness flag. *)
Record cond_promotion : Set := mk_cond_prom {
  (* Replacement condition for unsigned promotion ([sgn = false]), if any. *)
  cond_prom_usg : option condition;
  (* Replacement condition for signed promotion ([sgn = true]), if any. *)
  cond_prom_sgn : option condition;
}.

(** Selects the promoted condition matching the signedness [sgn]:
    [cond_prom_sgn] when [sgn] is [true], [cond_prom_usg] otherwise. *)
Definition cond_prom_sgb (sgn : bool) (prom : cond_promotion) : option condition :=
  if sgn then prom.(cond_prom_sgn) else prom.(cond_prom_usg).

(** The empty condition promotion: no replacement condition for either
    signedness. *)
Definition cond_prom_None: cond_promotion :=
  mk_cond_prom None None.

(** Weak promotability of [op] into [op'] on the arguments [vl].
    Whenever [op] evaluates to some [v] on [vl], the operation [op']
    evaluated on the arguments promoted according to the flags [pargs]
    (paired positionally with [vl] by [combine]) yields [v], promoted
    when [pres] is [true].
    [ge], [sp] and [m] are passed unchanged to [eval_operation].
    NOTE(review): [if_Some] comes from [OptionMonad]; it presumably holds
    trivially when [op] fails to evaluate, so nothing is required of
    [op'] in that case -- confirm against its definition. *)
Definition promotable_op [F V] ge sp m (sgn : bool)
    (op op' : operation) (pargs : list bool) (pres : bool) (vl : list val) : Prop :=
  if_Some (@eval_operation F V ge sp op vl m) (fun v =>
    eval_operation ge sp op' (map (promote_valb sgn) (combine pargs vl)) m = Some (promote_valb sgn (pres, v))).

(** Strong promotability of [op] into [op'] on the arguments [vl].
    The evaluation of [op'] on the promoted arguments is *equal* to the
    evaluation of [op] on [vl] with its result mapped through
    [promote_valb sgn (pres, _)]. Because the right-hand side is built
    with [option_map], this also forces [op'] to evaluate to [None]
    whenever [op] does. *)
Definition promotable_op_strong [F V] ge sp m (sgn : bool)
    (op op' : operation) (pargs : list bool) (pres : bool) (vl : list val) : Prop :=
  eval_operation ge sp op' (map (promote_valb sgn) (combine pargs vl)) m =
  option_map (fun v => promote_valb sgn (pres, v)) (@eval_operation F V ge sp op vl m).

(** Soundness of an operation promotion [prom] for [op] on [vl], relative
    to a promotability predicate [P] (its argument list has the shape of
    the last six arguments of [promotable_op] and
    [promotable_op_strong]).
    For both signednesses, if [prom] provides a replacement operation
    [op'] (as selected by [op_prom_sgb]), then [P] must hold of [op],
    [op'], and the argument/result flags stored in [prom]. *)
Definition sound_op_promotion
  (P : forall (sgn : bool) (op op' : operation) (pargs : list bool) (pres : bool) (vl : list val), Prop)
  (op : operation) (prom : op_promotion) (vl : list val) : Prop :=
  forall sgn : bool,
    if_Some (op_prom_sgb sgn prom) (fun op' => P sgn op op' prom.(op_prom_args) prom.(op_prom_res) vl).

(** The empty promotion [op_prom_None] is sound for every predicate [P],
    operation and argument list: it offers no replacement operation for
    either signedness, so there is nothing to check. *)
Lemma sound_op_promotion_None P op vl:
  sound_op_promotion P op op_prom_None vl.
Proof.
  (* Introduce the signedness flag, then treat both values uniformly. *)
  intro sgn.
  destruct sgn; constructor.
Qed.

(** Weak promotability of [cond] into [cond'] on the arguments [vl].
    Whenever [cond] evaluates to some boolean [b] on [vl], the condition
    [cond'] evaluated on the arguments all promoted with
    [promote_val sgn] yields the same [b].
    NOTE(review): as for [promotable_op], [if_Some] presumably makes this
    trivially true when [cond] fails to evaluate -- confirm. *)
Definition promotable_cond m (sgn : bool) (cond cond' : condition) (vl : list val) : Prop :=
  if_Some (eval_condition cond vl m) (fun b =>
    eval_condition cond' (map (promote_val sgn) vl) m = Some b).

(** Strong promotability of [cond] into [cond'] on the arguments [vl].
    The evaluation of [cond'] on the promoted arguments is *equal* to the
    evaluation of [cond] on [vl], including the case where the latter is
    [None]. *)
Definition promotable_cond_strong m (sgn : bool) (cond cond' : condition) (vl : list val) : Prop :=
  eval_condition cond' (map (promote_val sgn) vl) m = eval_condition cond vl m.

(** Soundness of a condition promotion [prom] for [cond] on [vl], relative
    to a promotability predicate [P] (its argument list has the shape of
    the last four arguments of [promotable_cond] and
    [promotable_cond_strong]).
    For both signednesses, if [prom] provides a replacement condition
    [cond'] (as selected by [cond_prom_sgb]), then [P] must hold of
    [cond] and [cond'] on [vl]. *)
Definition sound_cond_promotion
  (P : forall (sgn : bool) (cond cond' : condition) (vl : list val), Prop)
  (cond : condition) (prom : cond_promotion) (vl : list val) : Prop :=
  forall sgn : bool,
    if_Some (cond_prom_sgb sgn prom) (fun cond' => P sgn cond cond' vl).

(** The empty promotion [cond_prom_None] is sound for every predicate [P],
    condition and argument list: it offers no replacement condition for
    either signedness, so there is nothing to check. *)
Lemma sound_cond_promotion_None P cond vl:
  sound_cond_promotion P cond cond_prom_None vl.
Proof.
  (* Introduce the signedness flag, then treat both values uniformly. *)
  intro sgn.
  destruct sgn; constructor.
Qed.


(** [v] has the same promotion under both signednesses, i.e. the unsigned
    promotion ([sgn = false]) and the signed promotion ([sgn = true])
    of [v] coincide. *)
Definition val_promotes_eq (v : val) : Prop :=
  promote_val false v = promote_val true v.

(** Pointwise relation between a list of flags and a list of values
    (via [list_forall2]): each value whose flag is [true] must satisfy
    [val_promotes_eq]; values whose flag is [false] are unconstrained. *)
Definition list_val_promotes_eq : list bool -> list val -> Prop :=
  list_forall2 (fun (e : bool) v => if e then val_promotes_eq v else True).

(** An [op_promotion] extended with additional flags for arguments and
    result. The [:>] on [op_prom0] declares a coercion, so an
    [op_promotion_ceq] can be used where an [op_promotion] is expected.
    NOTE(review): the [_eq] flags are not used in this part of the file;
    presumably they mark the arguments/result that must satisfy
    [val_promotes_eq] (cf. [list_val_promotes_eq]) -- confirm at use
    sites. *)
Record op_promotion_ceq := mk_op_prom_ceq {
  (* Underlying promotion (coercion to [op_promotion]). *)
  op_prom0 :> op_promotion;
  (* One flag per argument. *)
  op_prom_args_eq : list bool;
  (* Flag for the result. *)
  op_prom_res_eq : bool;
}.

(** A [cond_promotion] extended with additional per-argument flags. The
    [:>] on [cond_prom0] declares a coercion, so a [cond_promotion_ceq]
    can be used where a [cond_promotion] is expected.
    NOTE(review): the [_eq] flags are not used in this part of the file;
    presumably they mark the arguments that must satisfy
    [val_promotes_eq] -- confirm at use sites. *)
Record cond_promotion_ceq := mk_cond_prom_ceq {
  (* Underlying promotion (coercion to [cond_promotion]). *)
  cond_prom0 :> cond_promotion;
  (* One flag per argument. *)
  cond_prom_args_eq : list bool;
}.



(** The signed interpretation of any [int] fits in the signed range of
    [int64]. *)
Lemma int64_int_signed_range (i : int):
  Int64.min_signed <= Int.signed i <= Int64.max_signed.
Proof.
  (* [Int.signed i] lies in the signed range of [int] ... *)
  pose proof (Int.signed_range i) as RANGE.
  (* ... which is included in the signed range of [int64]. *)
  assert (LO: Int64.min_signed <= Int.min_signed) by (cbn; lia).
  assert (HI: Int.max_signed <= Int64.max_signed) by (cbn; lia).
  lia.
Qed.

(** When the unsigned interpretation of [i] does not exceed
    [Int.max_signed], signed and unsigned promotion of [i] coincide. *)
Lemma promote_pos_eq i:
  in_int_pos_range (Int.unsigned i) ->
  promote false i = promote true i.
Proof.
  (* Only the upper bound of the range hypothesis is needed. *)
  intros [_ UPPER].
  unfold promote.
  (* Under that bound the two interpretations of [i] are equal. *)
  rewrite (Int.signed_eq_unsigned i UPPER).
  reflexivity.
Qed.

(** An [int] that is unsigned-less-than [Int.iwordsize] is also
    unsigned-less-than [Int64.iwordsize']. *)
Lemma ltu_wordsize_64 [i]
  (H : Int.ltu i Int.iwordsize = true):
  Int.ltu i Int64.iwordsize' = true.
Proof.
  (* Reduce the boolean goal to a strict inequality on unsigned values. *)
  apply zlt_true.
  (* Turn the boolean hypothesis into bounds on [Int.unsigned i]. *)
  apply Int.ltu_inv in H.
  (* [Int.iwordsize] is at most [Int64.iwordsize'], checked by computation. *)
  assert (Int.unsigned Int.iwordsize <= Int.unsigned Int64.iwordsize'). {
    unfold Int.iwordsize, Int64.iwordsize'.
    rewrite !Int.unsigned_repr; cbn; lia.
  }
  lia.
Qed.

(** Boolean comparison of two integers of [Z] according to the comparison
    [c]. [Cle] and [Cge] are expressed as the negation of the opposite
    strict comparison ([y < x] and [x < y] respectively).
    This exact shape matters: [int_cmpu_Z] and [int64_cmpu_Z] are proved
    by [reflexivity] against it. *)
Definition cmp_Z (c : comparison) (x y : Z) : bool :=
  match c with
  | Ceq => zeq x y
  | Cne => negb (zeq x y)
  | Clt => if zlt x y then true else false
  | Cle => negb (if zlt y x then true else false)
  | Cgt => if zlt y x then true else false
  | Cge => negb (if zlt x y then true else false)
  end.

(** Signed comparison on [int] is [cmp_Z] on the signed interpretations. *)
Lemma int_cmp_Z c x y:
  Int.cmp c x y = cmp_Z c (Int.signed x) (Int.signed y).
Proof.
  (* Case analysis on [c]; [Int.signed_eq] is rewritten with wherever it
     applies ([rewrite ?] does not fail when it does not). *)
  case c; simpl; rewrite ?Int.signed_eq; reflexivity.
Qed.

(** Unsigned comparison on [int] is [cmp_Z] on the unsigned
    interpretations. The two sides are convertible, so the proof is by
    [reflexivity]. *)
Lemma int_cmpu_Z c x y:
  Int.cmpu c x y = cmp_Z c (Int.unsigned x) (Int.unsigned y).
Proof.
  reflexivity.
Qed.

(** Signed comparison on [int64] is [cmp_Z] on the signed
    interpretations. *)
Lemma int64_cmp_Z c x y:
  Int64.cmp c x y = cmp_Z c (Int64.signed x) (Int64.signed y).
Proof.
  (* Case analysis on [c]; [Int64.signed_eq] is rewritten with wherever
     it applies ([rewrite ?] does not fail when it does not). *)
  case c; simpl; rewrite ?Int64.signed_eq; reflexivity.
Qed.

(** Unsigned comparison on [int64] is [cmp_Z] on the unsigned
    interpretations. The two sides are convertible, so the proof is by
    [reflexivity]. *)
Lemma int64_cmpu_Z c x y:
  Int64.cmpu c x y = cmp_Z c (Int64.unsigned x) (Int64.unsigned y).
Proof.
  reflexivity.
Qed.

(** Short aliases for the interval modules used in the rest of the file. *)
Module Itv := Interval.
Module Itv32 := Int_Modulus_Interval.

(** Decidable check that both bounds of [itv] keep it inside
    [0, Int.max_unsigned]: the lower bound is non-negative and the upper
    bound is at most [Int.max_unsigned]. *)
Definition is_in_unsigned_range (itv : Interval.t) : bool :=
  (0 <=? Itv.itv_lo itv) && (Itv.itv_hi itv <=? Int.max_unsigned).

(** Soundness of [is_in_unsigned_range]: any integer matched by an
    interval that passes the check satisfies [in_unsigned_range]. *)
Lemma is_in_unsigned_range_sound i itv:
  Itv.zmatch i itv ->
  is_in_unsigned_range itv = true ->
  in_unsigned_range i.
Proof.
  intros [LO HI] CHECK.
  (* Split the boolean conjunction into its two bound checks. *)
  unfold is_in_unsigned_range in CHECK.
  apply andb_true_iff in CHECK.
  destruct CHECK as [LO' HI'].
  (* Each bound on [i] follows by chaining through the interval bound. *)
  split; lia.
Qed.

(** Decidable check that both bounds of [itv] keep it inside
    [Int.min_signed, Int.max_signed]. *)
Definition is_in_signed_range (itv : Interval.t) : bool :=
  (Int.min_signed <=? Itv.itv_lo itv) && (Itv.itv_hi itv <=? Int.max_signed).

(** Soundness of [is_in_signed_range]: any integer matched by an interval
    that passes the check satisfies [in_signed_range]. *)
Lemma is_in_signed_range_sound i itv:
  Itv.zmatch i itv ->
  is_in_signed_range itv = true ->
  in_signed_range i.
Proof.
  intros [LO HI] CHECK.
  (* Split the boolean conjunction into its two bound checks. *)
  unfold is_in_signed_range in CHECK.
  apply andb_true_iff in CHECK.
  destruct CHECK as [LO' HI'].
  (* Each bound on [i] follows by chaining through the interval bound. *)
  split; lia.
Qed.

(** Decidable check that the upper bound of the underlying interval of
    [itv] is at most [Int.max_signed]. Only the upper bound is tested:
    [is_int_pos_sound] obtains the lower bound from [Itv32.mod_lo]. *)
Definition is_int_pos (itv : Itv32.t) : bool :=
  itv.(Itv32.mod_itv).(Itv.itv_hi) <=? Int.max_signed.

(** Soundness of [is_int_pos]: any integer matched by the underlying
    interval of an [Itv32.t] that passes the check satisfies
    [in_int_pos_range]. *)
Lemma is_int_pos_sound i itv:
  Itv.zmatch i itv.(Itv32.mod_itv) ->
  is_int_pos itv = true ->
  in_int_pos_range i.
Proof.
  intros [LO HI] CHECK.
  unfold is_int_pos in CHECK.
  unfold in_int_pos_range.
  (* The lower bound comes from the invariant carried by [itv] itself. *)
  pose proof (Itv32.mod_lo itv) as LO'.
  lia.
Qed.

(** The interval [Int.min_signed, Int.max_signed]; by [zmatch_sgn32_top]
    it matches the signed interpretation of every [int].
    NOTE(review): [Program] is presumably needed because [Interval.t]
    carries a proof field that is discharged automatically -- confirm
    against the definition of [Interval.t]. *)
Program Definition sgn32_top : Interval.t :=
  {| Itv.itv_lo := Int.min_signed; Itv.itv_hi := Int.max_signed |}.

(** The signed interpretation of any [int] is matched by [sgn32_top]. *)
Lemma zmatch_sgn32_top i:
  Itv.zmatch (Int.signed i) sgn32_top.
Proof.
  (* This is exactly the range property of [Int.signed]. *)
  exact (Int.signed_range i).
Qed.

(** Converts an interval on the unsigned interpretation of an [int] into
    an interval on its signed interpretation (see
    [zinterval_sgn_of_usg_correct]). Three cases on the underlying
    interval of [usg]:
    - its upper bound is below [Int.half_modulus]: the interval is
      returned unchanged;
    - otherwise, its lower bound is at least [Int.half_modulus]: both
      bounds are shifted down by [Int.modulus];
    - otherwise the interval straddles [Int.half_modulus] and the result
      is [sgn32_top]. *)
Program Definition zinterval_sgn_of_usg (usg : Itv32.t) : Interval.t :=
  let usg := usg.(Itv32.mod_itv) in
  if Itv.itv_hi usg <? Int.half_modulus
  then usg
  else if Itv.itv_lo usg >=? Int.half_modulus
  then {| Itv.itv_lo := usg.(Itv.itv_lo) - Int.modulus;
          Itv.itv_hi := usg.(Itv.itv_hi) - Int.modulus |}
  else sgn32_top.
Next Obligation.
  (* Remaining obligation generated by [Program] for the shifted interval
     (presumably that its lower bound does not exceed its upper bound);
     it follows from the fields of the original interval. *)
  destruct usg0 as [[] ? ?]; simpl in *; lia.
Qed.

(** Correctness of [zinterval_sgn_of_usg]: if the unsigned interpretation
    of [i] is matched by [usg], then its signed interpretation is matched
    by the converted interval. *)
Lemma zinterval_sgn_of_usg_correct (i : int) (usg : Itv32.t)
  (MATCH : Itv32.zmatch (Int.unsigned i) usg):
  Itv.zmatch (Int.signed i) (zinterval_sgn_of_usg usg).
Proof.
  (* Work with [u := Int.unsigned i], rewriting [i] as [Int.repr u]. *)
  rewrite <-(Int.repr_unsigned i); set (u := Int.unsigned i) in *.
  (* [u] is already reduced modulo the modulus. *)
  assert (UMOD: u mod Int_Modulus_Interval.modulus = u)
    by (apply Z.mod_small; apply Int.unsigned_range).
  (* Case analysis on the tests of [zinterval_sgn_of_usg]; the branch
     returning [sgn32_top] is closed immediately, the two remaining ones
     are expressed through [Int.signed_repr_eq]. *)
  unfold zinterval_sgn_of_usg; repeat autodestruct; intros; try solve [apply zmatch_sgn32_top];
    rewrite Int.signed_repr_eq, UMOD.
  (* Upper bound below [half_modulus]: the signed value is [u] itself. *)
  - rewrite zlt_true. assumption.
    case MATCH as [_ ?]; lia.
  (* Lower bound at least [half_modulus]: the signed value is
     [u - modulus], matched by the shifted interval. *)
  - rewrite zlt_false.
    + destruct MATCH; split; simpl; lia.
    + case MATCH as [? _]; lia.
Qed.