Module ExtValues


Require Import Reals.
Require Import Coqlib.
Require Import Integers.
Require Import Values.
Require Import Floats ExtFloats.
Require Import Lia.

Open Scope Z_scope.

(** [Z_abs_diff x y] is the absolute difference [|x - y|]. *)
Definition Z_abs_diff : Z -> Z -> Z :=
  fun x y => Z.abs (Z.sub x y).
(** Branching formulation of the absolute difference: subtract the smaller
    argument from the larger one (no call to [Z.abs]). *)
Definition Z_abs_diff2 : Z -> Z -> Z :=
  fun x y =>
    match Z.leb x y with
    | true => Z.sub y x
    | false => Z.sub x y
    end.
(** The two formulations of the absolute difference agree on all inputs. *)
Lemma Z_abs_diff2_correct :
  forall x y : Z, (Z_abs_diff x y) = (Z_abs_diff2 x y).
Proof.
  intros.
  unfold Z_abs_diff, Z_abs_diff2.
  (* Expose the [x ?= y] comparison inside [Z.leb], then case on it. *)
  unfold Z.leb.
  pose proof (Z.compare_spec x y) as Hspec.
  inv Hspec.
  (* Cases, in constructor order: x = y, x < y, y < x. *)
  - rewrite Z.abs_eq; lia.
  - rewrite Z.abs_neq; lia.
  - rewrite Z.abs_eq; lia.
Qed.

(** Enumeration of the four integers 1, 2, 3 and 4 (their numeric values
    are given by [z_of_shift1_4]).  The constructor names suggest these are
    small shift amounts — NOTE(review): confirm the intended use with the
    callers. *)
Inductive shift1_4 : Type :=
| SHIFT1 | SHIFT2 | SHIFT3 | SHIFT4.

(** Integer denoted by a [shift1_4] constructor: [SHIFTk] maps to [k]. *)
Definition z_of_shift1_4 : shift1_4 -> Z :=
  fun x =>
    match x with
    | SHIFT1 => 1 | SHIFT2 => 2 | SHIFT3 => 3 | SHIFT4 => 4
    end.

(** Partial inverse of [z_of_shift1_4]: [Some] of the matching constructor
    when [x] is 1, 2, 3 or 4, and [None] for every other integer.  The
    equality tests are performed in that order. *)
Definition shift1_4_of_z : Z -> option shift1_4 :=
  fun x =>
    match Z.eq_dec x 1 with
    | left _ => Some SHIFT1
    | right _ =>
      match Z.eq_dec x 2 with
      | left _ => Some SHIFT2
      | right _ =>
        match Z.eq_dec x 3 with
        | left _ => Some SHIFT3
        | right _ =>
          match Z.eq_dec x 4 with
          | left _ => Some SHIFT4
          | right _ => None
          end
        end
      end
    end.

(** Soundness of [shift1_4_of_z]: whenever it returns [Some x], mapping
    [x] back with [z_of_shift1_4] yields the original integer. *)
Lemma shift1_4_of_z_correct :
  forall z,
    match shift1_4_of_z z with
    | Some x => z_of_shift1_4 x = z
    | None => True
    end.
Proof.
  intro. unfold shift1_4_of_z.
  (* One [destruct] per equality test in [shift1_4_of_z]: the "equal"
     branch is closed by [congruence], the other one falls through to the
     next test. *)
  destruct (Z.eq_dec _ _); cbn; try congruence.
  destruct (Z.eq_dec _ _); cbn; try congruence.
  destruct (Z.eq_dec _ _); cbn; try congruence.
  destruct (Z.eq_dec _ _); cbn; try congruence.
  (* Remaining case: [None], goal [True]. *)
  trivial.
Qed.

(** The integer denoted by a [shift1_4], as a 32-bit machine integer. *)
Definition int_of_shift1_4 : shift1_4 -> int :=
  fun x => Int.repr (z_of_shift1_4 x).

(** [is_bitfield stop start] holds when the bit range [start..stop]
    (inclusive) is valid for a 32-bit word:
    [0 <= start <= stop < Int.zwordsize]. *)
Definition is_bitfield : Z -> Z -> bool :=
  fun stop start =>
    andb (andb (start <=? stop) (start >=? Z.zero)) (stop <? Int.zwordsize).

(** Unsigned bit-field extraction on 32-bit values: [extfz stop start v]
    returns bits [start..stop] (inclusive) of [v], zero-extended.
    A left shift by [zwordsize - (stop+1)] moves bit [stop] to the top of
    the word.  An unsigned right shift ([Int.shru]) by
    [zwordsize - (stop+1-start)] then brings bit [start] down to bit 0.
    The result is [Vundef] when [(stop, start)] is not a valid bit field
    ([is_bitfield]) or when [v] is not a [Vint]. *)
Definition extfz stop start v :=
  if is_bitfield stop start
  then
    let stop' := Z.add stop Z.one in
    (* [stop'] is the number of bits in positions [0..stop]. *)
    match v with
    | Vint w =>
      Vint (Int.shru (Int.shl w (Int.repr (Z.sub Int.zwordsize stop'))) (Int.repr (Z.sub Int.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.


(** Signed bit-field extraction on 32-bit values: same as [extfz], except
    that the final right shift is arithmetic ([Int.shr]), so the field
    [start..stop] is sign-extended from bit [stop].  The result is
    [Vundef] for an invalid bit field or a non-[Vint] argument. *)
Definition extfs stop start v :=
  if is_bitfield stop start
  then
    let stop' := Z.add stop Z.one in
    (* [stop'] is the number of bits in positions [0..stop]. *)
    match v with
    | Vint w =>
      Vint (Int.shr (Int.shl w (Int.repr (Z.sub Int.zwordsize stop'))) (Int.repr (Z.sub Int.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.

(** Mask with bits [start..stop] set: [2^(stop+1) - 2^start]. *)
Definition zbitfield_mask : Z -> Z -> Z :=
  fun stop start => Z.sub (Z.shiftl 1 (Z.succ stop)) (Z.shiftl 1 start).

(** [zbitfield_mask stop start] as a 32-bit value (reduced by [Int.repr]). *)
Definition bitfield_mask : Z -> Z -> val :=
  fun stop start => Vint (Int.repr (zbitfield_mask stop start)).

(** [zbitfield_mask stop start] as a 64-bit value (reduced by [Int64.repr]). *)
Definition bitfield_maskl : Z -> Z -> val :=
  fun stop start => Vlong (Int64.repr (zbitfield_mask stop start)).

(** Bit-field insertion on 32-bit values: [insf stop start prev fld] keeps
    the bits of [prev] outside [start..stop] (cleared with the inverted
    mask).  Inside that range it puts [fld] shifted left by [start] and
    masked to the field.  The result is [Vundef] when [(stop, start)] is
    not a valid bit field ([is_bitfield]). *)
Definition insf stop start prev fld :=
  let mask := bitfield_mask stop start in
  if is_bitfield stop start
  then
    Val.or (Val.and prev (Val.notint mask))
           (Val.and (Val.shl fld (Vint (Int.repr start))) mask)
  else Vundef.

(** 64-bit counterpart of [is_bitfield]: the bit range [start..stop] is
    valid when [0 <= start <= stop < Int64.zwordsize]. *)
Definition is_bitfieldl : Z -> Z -> bool :=
  fun stop start =>
    andb (andb (start <=? stop) (start >=? Z.zero)) (stop <? Int64.zwordsize).

(** Unsigned bit-field extraction on 64-bit values: returns bits
    [start..stop] of [v], zero-extended, using the same
    shift-left / unsigned-shift-right scheme as [extfz].  The shift
    amounts are passed as 32-bit [Int] values to [Int64.shl'] and
    [Int64.shru'].  The result is [Vundef] for an invalid bit field
    ([is_bitfieldl]) or a non-[Vlong] argument. *)
Definition extfzl stop start v :=
  if is_bitfieldl stop start
  then
    let stop' := Z.add stop Z.one in
    match v with
    | Vlong w =>
      Vlong (Int64.shru' (Int64.shl' w (Int.repr (Z.sub Int64.zwordsize stop'))) (Int.repr (Z.sub Int64.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.


(** Signed bit-field extraction on 64-bit values: like [extfzl], but the
    final right shift is arithmetic ([Int64.shr']), so the field is
    sign-extended from bit [stop].  The result is [Vundef] for an invalid
    bit field or a non-[Vlong] argument. *)
Definition extfsl stop start v :=
  if is_bitfieldl stop start
  then
    let stop' := Z.add stop Z.one in
    match v with
    | Vlong w =>
      Vlong (Int64.shr' (Int64.shl' w (Int.repr (Z.sub Int64.zwordsize stop'))) (Int.repr (Z.sub Int64.zwordsize (Z.sub stop' start))))
    | _ => Vundef
    end
  else Vundef.

(** Bit-field insertion on 64-bit values: [prev] with bits [start..stop]
    replaced by [fld] shifted left by [start] and masked.  The shift amount
    is passed to [Val.shll] as a [Vint].  The result is [Vundef] when
    [(stop, start)] is not a valid 64-bit bit field ([is_bitfieldl]). *)
Definition insfl stop start prev fld :=
  let mask := bitfield_maskl stop start in
  if is_bitfieldl stop start
  then
    Val.orl (Val.andl prev (Val.notl mask))
            (Val.andl (Val.shll fld (Vint (Int.repr start))) mask)
  else Vundef.

(** [highest_bit x n] scans the bits of [x] from index [n] down to index 1
    and returns the index of the first bit that is set.  If none of bits
    [1..n] is set, it returns 0.  Bit 0 itself is never tested, which is
    harmless because the answer for it would also be 0.  The recursion is
    structural on [n]. *)
Fixpoint highest_bit (x : Z) (n : nat) : Z :=
  match n with
  | O => 0
  | S n1 =>
    (* The index tested at this step is [n] itself, not [n1].  Uses
       [N.of_nat] rather than the legacy [N_of_nat] alias. *)
    let n' := Z.of_N (N.of_nat n) in
    if Z.testbit x n'
    then n'
    else highest_bit x n1
  end.

(** Index of the highest set bit of a 32-bit integer, read as unsigned
    (0 when no bit above bit 0 is set); see [highest_bit]. *)
Definition int_highest_bit : int -> Z :=
  fun x => highest_bit (Int.unsigned x) 31%nat.

(** Index of the highest set bit of a 64-bit integer, read as unsigned
    (0 when no bit above bit 0 is set); see [highest_bit]. *)
Definition int64_highest_bit : int64 -> Z :=
  fun x => highest_bit (Int64.unsigned x) 63%nat.

(** [Int.shrx] lifted to values.  It is defined only for two [Vint]
    arguments with a shift amount [n2] below 31 (unsigned comparison), and
    is [Vundef] otherwise. *)
Definition val_shrx : val -> val -> val :=
  fun v1 v2 =>
    match v1 with
    | Vint n1 =>
      match v2 with
      | Vint n2 =>
        match Int.ltu n2 (Int.repr 31) with
        | true => Vint (Int.shrx n1 n2)
        | false => Vundef
        end
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** [Int64.shrx'] lifted to values.  It takes a [Vlong] and a [Vint] shift
    amount [n2] below 63 (unsigned comparison), and is [Vundef] for any
    other argument shape. *)
Definition val_shrxl : val -> val -> val :=
  fun v1 v2 =>
    match v1 with
    | Vlong n1 =>
      match v2 with
      | Vint n2 =>
        match Int.ltu n2 (Int.repr 63) with
        | true => Vlong (Int64.shrx' n1 n2)
        | false => Vundef
        end
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** The 32-bit modulus is strictly smaller than the largest unsigned
    64-bit integer; proved by evaluation. *)
Remark modulus_fits_64: Int.modulus < Int64.max_unsigned.
Proof.
  compute.
  trivial.
Qed.

(** Every integer in [[0, Int.modulus)] survives a round trip through
    [Int64.repr] / [Int64.unsigned] unchanged, because it fits in 64 bits
    ([modulus_fits_64]).  The hypothesis [-1 < i] is [0 <= i]. *)
Remark unsigned64_repr :
  forall i,
    -1 < i < Int.modulus ->
    Int64.unsigned (Int64.repr i) = i.
Proof.
  intros i H.
  destruct H as [Hlow Hhigh].
  apply Int64.unsigned_repr.
  split. { lia. }
  pose proof modulus_fits_64.
  lia.
Qed.
  
(** 32-bit unsigned division can be computed as a 64-bit unsigned
    division:
    - zero-extend both operands ([Val.longofintu]);
    - divide with [Val.divlu];
    - keep the low word ([Val.loword]).
    Both sides are [None] in the same cases. *)
Theorem divu_is_divlu: forall v1 v2 : val,
    Val.divu v1 v2 =
    match Val.divlu (Val.longofintu v1) (Val.longofintu v2) with
    | None => None
    | Some q => Some (Val.loword q)
    end.
Proof.
  intros.
  destruct v1; cbn; trivial.
  destruct v2; cbn; trivial.
  (* Both operands are [Vint]; expose their underlying [Z] values. *)
  destruct i as [i_val i_range].
  destruct i0 as [i0_val i0_range].
  cbn.
  unfold Int.eq, Int64.eq, Int.zero, Int64.zero.
  cbn.
  rewrite Int.unsigned_repr by (compute; split; discriminate).
  rewrite (Int64.unsigned_repr 0) by (compute; split; discriminate).
  rewrite (unsigned64_repr i0_val) by assumption.
  (* Both sides test the same divisor against zero. *)
  destruct (zeq i0_val 0) as [ | Hnot0]; cbn; trivial.
  f_equal. f_equal.
  unfold Int.divu, Int64.divu. cbn.
  rewrite (unsigned64_repr i_val) by assumption.
  rewrite (unsigned64_repr i0_val) by assumption.
  unfold Int64.loword.
  rewrite Int64.unsigned_repr.
  reflexivity.
  (* Side goal from [Int64.unsigned_repr]: the quotient lies in
     [0, Int64.max_unsigned]. *)
  destruct (Z.eq_dec i0_val 1).
  {subst i0_val.
   pose proof modulus_fits_64.
   rewrite Zdiv_1_r.
   lia.
  }
  destruct (Z.eq_dec i_val 0).
  { subst i_val. compute.
    split;
    intro ABSURD;
    discriminate ABSURD. }
  (* General case: the quotient is below the dividend, itself below
     [Int.modulus]. *)
  assert ((i_val / i0_val) < i_val).
  { apply Z_div_lt; lia. }
  split.
  { apply Z_div_pos; lia. }
  pose proof modulus_fits_64.
  lia.
Qed.
  
(** 32-bit unsigned remainder can be computed as a 64-bit unsigned
    remainder on zero-extended operands, keeping the low word.  Both sides
    are [None] in the same cases. *)
Theorem modu_is_modlu: forall v1 v2 : val,
    Val.modu v1 v2 =
    match Val.modlu (Val.longofintu v1) (Val.longofintu v2) with
    | None => None
    | Some q => Some (Val.loword q)
    end.
Proof.
  intros.
  destruct v1; cbn; trivial.
  destruct v2; cbn; trivial.
  (* Both operands are [Vint]; expose their underlying [Z] values. *)
  destruct i as [i_val i_range].
  destruct i0 as [i0_val i0_range].
  cbn.
  unfold Int.eq, Int64.eq, Int.zero, Int64.zero.
  cbn.
  rewrite Int.unsigned_repr by (compute; split; discriminate).
  rewrite (Int64.unsigned_repr 0) by (compute; split; discriminate).
  rewrite (unsigned64_repr i0_val) by assumption.
  (* Both sides test the same divisor against zero. *)
  destruct (zeq i0_val 0) as [ | Hnot0]; cbn; trivial.
  f_equal. f_equal.
  unfold Int.modu, Int64.modu. cbn.
  rewrite (unsigned64_repr i_val) by assumption.
  rewrite (unsigned64_repr i0_val) by assumption.
  unfold Int64.loword.
  rewrite Int64.unsigned_repr.
  reflexivity.
  (* Side goal: the remainder lies in [0, Int64.max_unsigned].  It is
     below the divisor, which is below [Int.modulus]. *)
  assert((i_val mod i0_val) < i0_val).
  apply Z_mod_lt.
  lia.
  split.
  { apply Z_mod_lt.
    lia. }
  pose proof modulus_fits_64.
  lia.
Qed.

(** The following remarks record the outcome of comparisons with
    [Int.half_modulus] that are decided by computation.  Presumably they
    are meant as rewrite rules; verify against the users. *)

(** [0 < Int.half_modulus]: the first branch is taken. *)
Remark if_zlt_0_half_modulus :
  forall T : Type,
  forall x y: T,
    (if (zlt 0 Int.half_modulus) then x else y) = x.
Proof.
  reflexivity.
Qed.

(** [Int.unsigned Int.mone] is not below [Int.half_modulus]: the second
    branch is taken. *)
Remark if_zlt_mone_half_modulus :
  forall T : Type,
  forall x y: T,
    (if (zlt (Int.unsigned Int.mone) Int.half_modulus) then x else y) = y.
Proof.
  reflexivity.
Qed.

(** [Int.unsigned (Int.repr Int.min_signed)] is not below
    [Int.half_modulus]: the second branch is taken. *)
Remark if_zlt_min_signed_half_modulus :
  forall T : Type,
  forall x y: T,
    (if (zlt (Int.unsigned (Int.repr Int.min_signed))
                     Int.half_modulus)
    then x
     else y) = y.
Proof.
  reflexivity.
Qed.

(** Truncating to 32 bits after truncating to 64 bits is the same as
    truncating to 32 bits directly.  The two [Int.repr] arguments are
    congruent modulo [2^32] because [2^64] is a multiple of [2^32]. *)
Lemma repr_unsigned64_repr:
  forall x, Int.repr (Int64.unsigned (Int64.repr x)) = Int.repr x.
Proof.
  intros.
  apply Int.eqm_samerepr.
  unfold Int.eqm.
  unfold Zbits.eqmod.
  pose proof (Int64.eqm_unsigned_repr x) as H64.
  unfold Int64.eqm in H64.
  unfold Zbits.eqmod in H64.
  destruct H64 as [k64 H64].
  change Int64.modulus with 18446744073709551616 in *.
  change Int.modulus with 4294967296.
  (* 2^64 = 2^32 * 2^32: the 64-bit witness [k64] gives the 32-bit
     witness [-2^32 * k64]. *)
  exists (-4294967296 * k64).
  set (y := Int64.unsigned (Int64.repr x)) in *.
  rewrite H64.
  clear H64.
  lia.
Qed.


(** A 32-bit integer whose unsigned value is at least [Int.half_modulus]
    has signed value [unsigned - Int.modulus]. *)
Lemma big_unsigned_signed:
  forall x,
    (Int.unsigned x >= Int.half_modulus) ->
    (Int.signed x) = (Int.unsigned x) - Int.modulus.
Proof.
  destruct x as [xval xrange].
  intro BIG.
  unfold Int.signed, Int.unsigned in *. cbn in *.
  (* The "small" branch contradicts [BIG]; the other one is the goal. *)
  destruct (zlt _ _).
  lia.
  trivial.
Qed.


(** Truncated division of a non-negative [a] by [b >= 1] does not exceed
    [a]. *)
Lemma Z_quot_le: forall a b,
    0 <= a -> 1 <= b -> Z.quot a b <= a.
Proof.
  intros a b Ha Hb.
  destruct (Z.eq_dec b 1) as [Hb1 | Hb1].
  { (* b=1 *)
    subst.
    rewrite Z.quot_1_r.
    auto with zarith.
  }
  destruct (Z.eq_dec a 0) as [Ha0 | Ha0].
  { (* a=0 *)
    subst.
    rewrite Z.quot_0_l.
    auto with zarith.
    lia.
  }
  (* Remaining case: a > 0 and b > 1, where the inequality is strict. *)
  assert ((Z.quot a b) < a).
  {
    apply Z.quot_lt; lia.
  }
  auto with zarith.
Qed.


Require Import Coq.ZArith.Zquot.
(** Truncated division by [b >= 1] keeps a value of [[0, m]] inside
    [[0, m]]. *)
Lemma Z_quot_pos_pos_bound: forall a b m,
    0 <= a <= m -> 1 <= b -> 0 <= Z.quot a b <= m.
Proof.
  intros.
  split.
  (* Lower bound: monotonicity of [Z.quot] from [0 ÷ b = 0]. *)
  { rewrite <- (Z.quot_0_l b) by lia.
    apply Z_quot_monotone; lia.
  }
  (* Upper bound: [a ÷ b <= a <= m]. *)
  apply Z.le_trans with (m := a).
  {
    apply Z_quot_le; tauto.
  }
  tauto.
Qed.
(** Truncated division by [b >= 1] keeps a value of [[m, 0]] inside
    [[m, 0]]. *)
Lemma Z_quot_neg_pos_bound: forall a b m,
    m <= a <= 0 -> 1 <= b -> m <= Z.quot a b <= 0.
Proof.
  intros.
  (* Reduce to the non-negative case using [(-a) ÷ b = -(a ÷ b)]. *)
  assert (0 <= - (a ÷ b) <= -m).
  {
    rewrite <- Z.quot_opp_l by lia.
    apply Z_quot_pos_pos_bound; lia.
  }
  lia.
Qed.

(** Truncated division of a signed 32-bit value by a positive divisor
    stays within the signed 32-bit range. *)
Lemma Z_quot_signed_pos_bound: forall a b,
    Int.min_signed <= a <= Int.max_signed -> 1 <= b ->
    Int.min_signed <= Z.quot a b <= Int.max_signed.
Proof.
  intros.
  destruct (Z_lt_ge_dec a 0).
  (* Case a < 0: the quotient lies in [min_signed, 0]. *)
  {
    split.
    { apply Z_quot_neg_pos_bound; lia. }
    { eapply Z.le_trans with (m := 0).
      { apply Z_quot_neg_pos_bound with (m := Int.min_signed); trivial.
        split. tauto. auto with zarith.
      }
      discriminate.
    }
  }
  (* Case a >= 0: the quotient lies in [0, max_signed]. *)
  { split.
    { eapply Z.le_trans with (m := 0).
      discriminate.
      apply Z_quot_pos_pos_bound with (m := Int.max_signed); trivial.
      split. lia. tauto.
    }
    { apply Z_quot_pos_pos_bound; lia.
    }
  }
Qed.

(** Truncated division of a signed 32-bit value by a divisor [b < -1]
    stays within the signed 32-bit range.  The divisor [-1] is excluded
    because [min_signed ÷ -1 = 2^31] exceeds [max_signed]. *)
Lemma Z_quot_signed_neg_bound: forall a b,
    Int.min_signed <= a <= Int.max_signed -> b < -1 ->
    Int.min_signed <= Z.quot a b <= Int.max_signed.
Proof.
  change Int.min_signed with (-2147483648).
  change Int.max_signed with 2147483647.
  intros.

  (* Write [a ÷ b] as [-(a ÷ -b)]; it suffices to bound the quotient by
     the positive divisor [-b]. *)
  replace b with (-(-b)) by auto with zarith.
  rewrite Z.quot_opp_r by lia.
  assert (-2147483647 <= (a ÷ - b) <= 2147483648).
  2: lia.
  
  destruct (Z_lt_ge_dec a 0).
  (* Case a < 0: also negate the dividend to make it positive. *)
  {
    replace a with (-(-a)) by auto with zarith.
    rewrite Z.quot_opp_l by lia.
    assert (-2147483648 <= - a ÷ - b <= 2147483647).
    2: lia.
    split.
    {
      rewrite Z.quot_opp_l by lia.
      assert (a ÷ - b <= 2147483648).
      2: lia.
      {
        apply Z.le_trans with (m := 0).
        rewrite <- (Z.quot_0_l (-b)) by lia.
        apply Z_quot_monotone; lia.
        discriminate.
      }
    }
    assert (- a ÷ - b < -a ).
    2: lia.
    apply Z_quot_lt; lia.
  }
  (* Case a >= 0: the quotient lies in [0, a]. *)
  {
    split.
    { apply Z.le_trans with (m := 0).
      discriminate.
      rewrite <- (Z.quot_0_l (-b)) by lia.
      apply Z_quot_monotone; lia.
    }
    { apply Z.le_trans with (m := a).
      apply Z_quot_le.
      all: lia.
    }
  }
Qed.

(** Subtraction on values is addition of the negation. *)
Lemma sub_add_neg :
  forall x y, Val.sub x y = Val.add x (Val.neg y).
Proof.
  destruct x; destruct y; cbn; trivial.
  (* Remaining [Vint]/[Vint] case: integer identity. *)
  f_equal.
  apply Int.sub_add_opp.
Qed.

(** On 32-bit values, negating a product is the same as negating its
    right factor. *)
Lemma neg_mul_distr_r :
  forall x y, Val.neg (Val.mul x y) = Val.mul x (Val.neg y).
Proof.
  destruct x; destruct y; cbn; trivial.
  (* Remaining [Vint]/[Vint] case: integer identity. *)
  f_equal.
  apply Int.neg_mul_distr_r.
Qed.


(** 64-bit counterpart of [neg_mul_distr_r]. *)
Lemma negl_mull_distr_r :
  forall x y, Val.negl (Val.mull x y) = Val.mull x (Val.negl y).
Proof.
  destruct x; destruct y; cbn; trivial.
  (* Remaining [Vlong]/[Vlong] case: integer identity. *)
  f_equal.
  apply Int64.neg_mul_distr_r.
Qed.

(** [addx sh v1 v2 = v2 + (v1 << sh)] on 32-bit values.  Note that the
    shifted operand is the first value argument. *)
Definition addx : int -> val -> val -> val :=
  fun sh v1 v2 => Val.add v2 (Val.shl v1 (Vint sh)).

(** [addxl sh v1 v2 = v2 + (v1 << sh)] on 64-bit values. *)
Definition addxl : int -> val -> val -> val :=
  fun sh v1 v2 => Val.addl v2 (Val.shll v1 (Vint sh)).

(** [revsubx sh v1 v2 = v2 - (v1 << sh)] on 32-bit values. *)
Definition revsubx : int -> val -> val -> val :=
  fun sh v1 v2 => Val.sub v2 (Val.shl v1 (Vint sh)).

(** [revsubxl sh v1 v2 = v2 - (v1 << sh)] on 64-bit values. *)
Definition revsubxl : int -> val -> val -> val :=
  fun sh v1 v2 => Val.subl v2 (Val.shll v1 (Vint sh)).

(** [ExtFloat.min] on two [Vfloat] values; [Vundef] otherwise. *)
Definition minf : val -> val -> val :=
  fun v1 v2 =>
    match v1 with
    | Vfloat a =>
      match v2 with
      | Vfloat b => Vfloat (ExtFloat.min a b)
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** [ExtFloat.max] on two [Vfloat] values; [Vundef] otherwise. *)
Definition maxf : val -> val -> val :=
  fun v1 v2 =>
    match v1 with
    | Vfloat a =>
      match v2 with
      | Vfloat b => Vfloat (ExtFloat.max a b)
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** [ExtFloat32.min] on two [Vsingle] values; [Vundef] otherwise. *)
Definition minfs : val -> val -> val :=
  fun v1 v2 =>
    match v1 with
    | Vsingle a =>
      match v2 with
      | Vsingle b => Vsingle (ExtFloat32.min a b)
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** [ExtFloat32.max] on two [Vsingle] values; [Vundef] otherwise. *)
Definition maxfs : val -> val -> val :=
  fun v1 v2 =>
    match v1 with
    | Vsingle a =>
      match v2 with
      | Vsingle b => Vsingle (ExtFloat32.max a b)
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** [ExtFloat32.inv] on a [Vsingle] value; [Vundef] otherwise. *)
Definition invfs : val -> val :=
  fun v1 =>
    match v1 with
    | Vsingle a => Vsingle (ExtFloat32.inv a)
    | _ => Vundef
    end.

(** Lift a ternary operation on doubles to values: defined on three
    [Vfloat] arguments, [Vundef] otherwise. *)
Definition triple_op_float :=
  fun f v1 v2 v3 =>
    match v1 with
    | Vfloat a =>
      match v2 with
      | Vfloat b =>
        match v3 with
        | Vfloat c => Vfloat (f a b c)
        | _ => Vundef
        end
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** Lift a ternary operation on singles to values: defined on three
    [Vsingle] arguments, [Vundef] otherwise. *)
Definition triple_op_single :=
  fun f v1 v2 v3 =>
    match v1 with
    | Vsingle a =>
      match v2 with
      | Vsingle b =>
        match v3 with
        | Vsingle c => Vsingle (f a b c)
        | _ => Vundef
        end
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** [fmaddf acc x y = fma(x, y, acc)], presumably [acc + x*y] with a single
    rounding (see [Float.fma]); [fmaddfs] is the single-precision version. *)
Definition fmaddf := triple_op_float (fun acc x y => Float.fma x y acc).
Definition fmaddfs := triple_op_single (fun acc x y => Float32.fma x y acc).

(** [fmsubf acc x y = fma(-x, y, acc)], presumably [acc - x*y] with a single
    rounding; [fmsubfs] is the single-precision version. *)
Definition fmsubf := triple_op_float (fun acc x y => Float.fma (Float.neg x) y acc).
Definition fmsubfs := triple_op_single (fun acc x y => Float32.fma (Float32.neg x) y acc).

From Flocq Require Import Core Digits Operations Round Bracket Sterbenz
                          Binary Round_odd.
Require Import IEEE754_extra Zdiv Psatz Floats ExtFloats.

(** Integer division [a / b] recovered from a real approximation [x] of the
    quotient.  [x] is first rounded with [ZnearestE].  If the resulting
    remainder [a - q*b] is negative, the candidate is one too large and is
    decremented.  [div_approx_reals_correct] gives the conditions under
    which the result is exact. *)
Definition div_approx_reals : Z -> Z -> R -> Z :=
  fun a b x =>
    let quo := ZnearestE x in
    let rem := a - quo * b in
    match rem <? 0 with
    | true => quo - 1
    | false => quo
    end.

(** If the real [x] is at distance less than 1 from the integer [y], then
    [Zfloor x] is either [y - 1] or [y]. *)
Lemma floor_ball1:
  forall x : R, forall y : Z,
    (Rabs (x - IZR y) < 1)%R -> Zfloor x = (y-1)%Z \/ Zfloor x = y.
Proof.
  intros x y BALL.
  apply Rabs_lt_inv in BALL.
  (* Three cases: x < y, x = y, x > y. *)
  case (Rcompare_spec x (IZR y)); intro CMP.
  - left. apply Zfloor_imp.
    ring_simplify (y-1+1).
    rewrite minus_IZR. lra.
  - subst.
    rewrite Zfloor_IZR. right. reflexivity.
  - right. apply Zfloor_imp.
    rewrite plus_IZR. lra.
Qed.

(** If [b > 0] and [x] approximates the real quotient [a/b] to within
    [1/2], then [div_approx_reals a b x] is exactly the integer division
    [a / b]. *)
Theorem div_approx_reals_correct:
  forall a b : Z, forall x : R,
    b > 0 ->
    (Rabs (x - IZR a/ IZR b) < 1/2)%R ->
    div_approx_reals a b x = (a/b)%Z.
Proof.
  intros a b x bPOS GAP.
  assert (0 < IZR b)%R by (apply IZR_lt ; lia).
  unfold div_approx_reals.
  (* Rounding to nearest moves [x] by at most 1/2 ... *)
  pose proof (Znearest_imp2 (fun x => negb (Z.even x)) x) as NEAR.
  (* ... so, by the triangle inequality, the rounded value is within 1 of
     [a/b]. *)
  assert (Rabs (IZR (ZnearestE x) - IZR a/ IZR b) < 1)%R as BALL.
  { pose proof (Rabs_triang (IZR (ZnearestE x) - x)
                            (x - IZR a/ IZR b)) as TRI.
    ring_simplify (IZR (ZnearestE x) - x + (x - IZR a / IZR b))%R in TRI.
    lra.
  }
  clear GAP NEAR.
  rewrite Rabs_minus_sym in BALL.
  (* Hence [a / b] (the floor of the real quotient) is the rounded value
     or one less. *)
  pose proof (floor_ball1 _ _ BALL) as FLOOR.
  clear BALL.
  rewrite Zfloor_div in FLOOR by lia.
  pose proof (Z_div_mod_eq_full a b) as DIV_MOD.
  assert (0 < b) as bPOS' by lia.
  pose proof (Z.mod_pos_bound a b bPOS') as MOD_BOUNDS.
  (* The sign of the remainder selects the right candidate. *)
  case Z.ltb_spec; intro; destruct FLOOR; lia.
Qed.

(** Floating-point approximation of [a / b] for unsigned 32-bit values.
    The computation proceeds as follows:
    - [b] is converted to a double (unsigned conversion);
    - a single-precision reciprocal ([invfs]) is widened to double;
    - [residual = fma(-inv, b, 1)] is formed ([fmsubf]);
    - the reciprocal is refined to [fma(residual, inv, inv)] ([fmaddf]);
    - the result is [a] (as a double) times that refined reciprocal. *)
Definition my_div : val -> val -> val :=
  fun a b =>
    let bd := Val.maketotal (Val.floatofintu b) in
    let inv_bd := Val.floatofsingle (invfs (Val.maketotal (Val.singleofintu b))) in
    let residual := fmsubf (Vfloat ExtFloat.one) inv_bd bd in
    let refined := fmaddf inv_bd residual inv_bd in
    Val.mulf (Val.maketotal (Val.floatofintu a)) refined.

(** Absolute value of the signed reading of a 32-bit integer, reduced back
    with [Int.repr].  The absolute value of [Int.min_signed] therefore wraps
    around. *)
Definition int_abs : int -> int :=
  fun i1 => Int.repr (Z.abs (Int.signed i1)).
(** 64-bit counterpart of [int_abs]. *)
Definition long_abs : int64 -> int64 :=
  fun i1 => Int64.repr (Z.abs (Int64.signed i1)).

(** Absolute difference of the signed readings of two 32-bit integers,
    reduced back with [Int.repr]. *)
Definition int_absdiff : int -> int -> int :=
  fun i1 i2 => Int.repr (Z_abs_diff (Int.signed i1) (Int.signed i2)).

(** 64-bit counterpart of [int_absdiff]. *)
Definition long_absdiff : int64 -> int64 -> int64 :=
  fun i1 i2 => Int64.repr (Z_abs_diff (Int64.signed i1) (Int64.signed i2)).

(** The absolute value is the absolute difference with zero (32 bits). *)
Lemma int_absdiff_zero :
  forall i, int_abs i = int_absdiff i Int.zero.
Proof.
  intro. unfold int_abs, int_absdiff, Z_abs_diff.
  (* [Int.signed Int.zero] evaluates to 0. *)
  change (Int.signed Int.zero) with 0%Z.
  rewrite Z.sub_0_r.
  reflexivity.
Qed.

(** The absolute value is the absolute difference with zero (64 bits). *)
Lemma long_absdiff_zero :
  forall i, long_abs i = long_absdiff i Int64.zero.
Proof.
  intro. unfold long_abs, long_absdiff, Z_abs_diff.
  (* [Int64.signed Int64.zero] evaluates to 0. *)
  change (Int64.signed Int64.zero) with 0%Z.
  rewrite Z.sub_0_r.
  reflexivity.
Qed.

(** Lift a binary operation on 32-bit integers to values: defined on two
    [Vint] arguments, [Vundef] otherwise. *)
Definition double_op_int :=
  fun (f : int -> int -> int) (v1 v2 : val) =>
    match v1 with
    | Vint a =>
      match v2 with
      | Vint b => Vint (f a b)
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** Lift a binary operation on 64-bit integers to values: defined on two
    [Vlong] arguments, [Vundef] otherwise. *)
Definition double_op_long :=
  fun (f : int64 -> int64 -> int64) (v1 v2 : val) =>
    match v1 with
    | Vlong a =>
      match v2 with
      | Vlong b => Vlong (f a b)
      | _ => Vundef
      end
    | _ => Vundef
    end.

(** Absolute difference of signed integers, lifted to 32-bit values. *)
Definition absdiff := double_op_int int_absdiff.
(** Absolute difference of signed integers, lifted to 64-bit values. *)
Definition absdiffl := double_op_long long_absdiff.

(** Absolute value of a [Vint], read as signed and reduced back with
    [Int.repr]; [Vundef] for any other value. *)
Definition abs : val -> val :=
  fun v1 =>
    match v1 with
    | Vint n => Vint (Int.repr (Z.abs (Int.signed n)))
    | _ => Vundef
    end.

(** Absolute value of a [Vlong], read as signed and reduced back with
    [Int64.repr]; [Vundef] for any other value. *)
Definition absl : val -> val :=
  fun v1 =>
    match v1 with
    | Vlong n => Vlong (Int64.repr (Z.abs (Int64.signed n)))
    | _ => Vundef
    end.

(** On values, [abs v] coincides with the absolute difference between [v]
    and 32-bit zero. *)
Lemma absdiff_zero_correct:
  forall v, abs v = absdiff v (Vint Int.zero).
Proof.
  intro. destruct v; cbn; try reflexivity.
  f_equal. unfold int_absdiff, Z_abs_diff.
  (* [int_absdiff] reads its arguments with [Int.signed], so the subterm to
     normalize is [Int.signed Int.zero] (as in [int_absdiff_zero]).
     [Int.unsigned Int.zero] does not occur in this goal. *)
  change (Int.signed Int.zero) with 0%Z.
  rewrite Z.sub_0_r.
  reflexivity.
Qed.

(** On values, [absl v] coincides with the absolute difference between [v]
    and 64-bit zero. *)
Lemma absdiffl_zero_correct:
  forall v, absl v = absdiffl v (Vlong Int64.zero).
Proof.
  intro. destruct v; cbn; try reflexivity.
  f_equal. unfold long_absdiff, Z_abs_diff.
  (* [long_absdiff] reads its arguments with [Int64.signed], so the subterm
     to normalize is [Int64.signed Int64.zero] (as in [long_absdiff_zero]).
     [Int64.unsigned Int64.zero] does not occur in this goal. *)
  change (Int64.signed Int64.zero) with 0%Z.
  rewrite Z.sub_0_r.
  reflexivity.
Qed.

(** [absdiff] is compatible with memory injections: related arguments give
    related results. *)
Remark absdiff_inject:
  forall f v1 v1' v2 v2'
  (INJ1 : Val.inject f v1 v1')
  (INJ2 : Val.inject f v2 v2'),
  Val.inject f (absdiff v1 v2) (absdiff v1' v2').
Proof.
  intros.
  (* Non-[Vint] cases produce [Vundef] on the left, which injects into
     anything. *)
  inv INJ1; cbn; try constructor.
  inv INJ2; cbn; constructor.
Qed.

(** [absdiffl] is compatible with memory injections: related arguments
    give related results. *)
Remark absdiffl_inject:
  forall f v1 v1' v2 v2'
  (INJ1 : Val.inject f v1 v1')
  (INJ2 : Val.inject f v2 v2'),
  Val.inject f (absdiffl v1 v2) (absdiffl v1' v2').
Proof.
  intros.
  (* Non-[Vlong] cases produce [Vundef] on the left, which injects into
     anything. *)
  inv INJ1; cbn; try constructor.
  inv INJ2; cbn; constructor.
Qed.