Module ExtValues


Require Import Coqlib.
Require Import Integers.
Require Import Values.
Require Import Lia.
Require ExtFloats.

Open Scope Z_scope.

Fixpoint pos_highest_bit (x : positive) (already : nat) : nat :=
  match x with
  | xH => already
  | xI x' | xO x' => pos_highest_bit x' (S already)
  end.

Lemma pos_highest_bit_range :
  forall x p (RANGE : (x < Pos.of_nat (2^p)%nat)%positive) already,
    (already <= pos_highest_bit x already < already + p)%nat.
Proof.
  induction x; intros; cbn.
  all: destruct p; cbn in RANGE; [lia|idtac].
  3: lia.
  all: rewrite Nat.add_0_r in RANGE.
  all: replace (2^p + 2^p)%nat with (2 * (2^p))%nat in RANGE by lia.
  {
    assert (S already <= pos_highest_bit x (S already) < S already + p)%nat.
    { apply IHx. lia. }
    lia.
  }
  {
    assert (S already <= pos_highest_bit x (S already) < S already + p)%nat.
    { apply IHx. lia. }
    lia.
  }
Qed.

Definition highest_bit x :=
  match x with
  | Zpos x => Z.of_nat (pos_highest_bit x O)
  | _ => 0%Z
  end.

Definition int_highest_bit (x : int) : Z :=
  highest_bit (Int.unsigned x).

Opaque Pos.pow Nat.pow.

Lemma int_highest_bit_range :
  forall x, 0 <= int_highest_bit x <= 31.
Proof.
  unfold int_highest_bit, highest_bit.
  destruct x; cbn.
  destruct intval; try lia.
  assert (RANGE : (p < Pos.of_nat (2^32)%nat)%positive).
  { rewrite Nat2Pos.inj_pow by lia.
    change Int.modulus with (2^32) in *.
    change (Pos.of_nat 2) with 2%positive.
    change (Pos.of_nat 32) with 32%positive.
    lia.
  }
  pose proof (pos_highest_bit_range p 32%nat RANGE O).
  lia.
Qed.

Fixpoint pos_lowest_bit (x : positive) (already : nat) : nat :=
  match x with
  | xH | xI _ => already
  | xO x' => pos_lowest_bit x' (S already)
  end.

Transparent Pos.pow Nat.pow.

Lemma pos_lowest_bit_range :
  forall x p (RANGE : (x < Pos.of_nat (2^p)%nat)%positive) already,
    (already <= pos_lowest_bit x already < already + p)%nat.
Proof.
  induction x; intros; cbn.
  all: destruct p; cbn in RANGE; [lia|idtac].
  1,3: lia.
  rewrite Nat.add_0_r in RANGE.
  replace (2^p + 2^p)%nat with (2 * (2^p))%nat in RANGE by lia.
  assert (S already <= pos_lowest_bit x (S already) < S already + p)%nat.
  { apply IHx. lia. }
  lia.
Qed.

Definition lowest_bit x :=
  match x with
  | Zpos x => Z.of_nat (pos_lowest_bit x O)
  | _ => 0%Z
  end.

Definition int_lowest_bit (x : int) : Z :=
  lowest_bit (Int.unsigned x).

Opaque Pos.pow Nat.pow.

Lemma int_lowest_bit_range :
  forall x, 0 <= int_lowest_bit x <= 31.
Proof.
  unfold int_lowest_bit, lowest_bit.
  destruct x; cbn.
  destruct intval; try lia.
  assert (RANGE : (p < Pos.of_nat (2^32)%nat)%positive).
  { rewrite Nat2Pos.inj_pow by lia.
    change Int.modulus with (2^32) in *.
    change (Pos.of_nat 2) with 2%positive.
    change (Pos.of_nat 32) with 32%positive.
    lia.
  }
  pose proof (pos_lowest_bit_range p 32%nat RANGE O).
  lia.
Qed.

Lemma pos_highest_bit_bigger:
  forall x a, (a <= pos_highest_bit x a)%nat.
Proof.
  induction x; cbn; intros.
  1,2: pose proof (IHx (S a)); lia.
  lia.
Qed.

Lemma pos_lowest_highest_bit:
  forall x a, (pos_lowest_bit x a <= pos_highest_bit x a)%nat.
Proof.
  induction x; cbn; intros.
  - pose proof (pos_highest_bit_bigger x (S a)). lia.
  - apply IHx.
  - lia.
Qed.

Lemma lowest_highest_bit:
  forall n, lowest_bit n <= highest_bit n.
Proof.
  destruct n; cbn; try lia.
  pose proof (pos_lowest_highest_bit p 0).
  lia.
Qed.

Lemma int_lowest_highest_bit:
  forall n, ExtValues.int_lowest_bit n <= ExtValues.int_highest_bit n.
Proof.
  intro. apply lowest_highest_bit.
Qed.

Definition is_bitfield lsb sz :=
  (Int.unsigned lsb) + (Int.unsigned sz) <=? 32.

Definition bitfield_mask lsb sz := Int.shl (Int.repr (Z.ones (Int.unsigned sz))) lsb.

Lemma zero_ext_shru_and_bitfield_mask_zero:
  forall lsb sz i,
    Int.ltu lsb Int.iwordsize = true ->
    Int.eq (Int.zero_ext (Int.unsigned sz) (Int.shru i lsb)) Int.zero =
    Int.eq (Int.and i (bitfield_mask lsb sz)) Int.zero.
Proof.
  intros lsb sz i Hlsb.
  apply Int.ltu_inv in Hlsb.
  rewrite Int.unsigned_repr_wordsize in Hlsb.
  destruct Hlsb as [Hlsb_lo Hlsb_hi].
  pose proof (Int.unsigned_range sz) as Hsz_range.
  set (LHS := Int.zero_ext (Int.unsigned sz) (Int.shru i lsb)).
  set (RHS := Int.and i (bitfield_mask lsb sz)).
  assert (Hequiv: LHS = Int.zero <-> RHS = Int.zero).
  { split.
    - intros HL.
      apply Int.same_bits_eq. intros n Hn. rewrite Int.bits_zero.
      unfold RHS. rewrite Int.bits_and by lia.
      unfold bitfield_mask. rewrite Int.bits_shl by lia.
      destruct (zlt n (Int.unsigned lsb)) as [|Hge].
      + apply Bool.andb_false_r.
      + rewrite Int.testbit_repr by lia. rewrite Z.testbit_ones by lia.
        destruct (Z.leb_spec 0 (n - Int.unsigned lsb)) as [Hle|Hlt]; try lia.
        destruct (Z.ltb_spec (n - Int.unsigned lsb) (Int.unsigned sz)) as [Hin|Hout].
        2: { apply Bool.andb_false_r. }
        rewrite Bool.andb_true_r.
        assert (HE: Int.testbit LHS (n - Int.unsigned lsb) = Int.testbit Int.zero (n - Int.unsigned lsb)).
        { rewrite HL. reflexivity. }
        rewrite Int.bits_zero in HE. unfold LHS in HE.
        rewrite Int.bits_zero_ext in HE by lia.
        rewrite zlt_true in HE by exact Hin.
        rewrite Int.bits_shru in HE by lia.
        rewrite zlt_true in HE by lia.
        replace (n - Int.unsigned lsb + Int.unsigned lsb) with n in HE by lia.
        exact HE.
    - intros HR.
      apply Int.same_bits_eq. intros n Hn. rewrite Int.bits_zero.
      unfold LHS. rewrite Int.bits_zero_ext by lia.
      destruct (zlt n (Int.unsigned sz)) as [Hlt|]; auto.
      rewrite Int.bits_shru by lia.
      destruct (zlt (n + Int.unsigned lsb) Int.zwordsize) as [Hlt2|]; auto.
      assert (HE: Int.testbit RHS (n + Int.unsigned lsb) = Int.testbit Int.zero (n + Int.unsigned lsb)).
      { rewrite HR. reflexivity. }
      rewrite Int.bits_zero in HE. unfold RHS in HE.
      rewrite Int.bits_and in HE by lia.
      unfold bitfield_mask in HE.
      rewrite Int.bits_shl in HE by lia.
      rewrite zlt_false in HE by lia.
      rewrite Int.testbit_repr in HE by lia.
      rewrite Z.testbit_ones in HE by lia.
      replace (n + Int.unsigned lsb - Int.unsigned lsb) with n in HE by lia.
      destruct (Z.leb_spec 0 n) as [|]; try lia.
      destruct (Z.ltb_spec n (Int.unsigned sz)) as [|]; try lia.
      simpl in HE. rewrite Bool.andb_true_r in HE. exact HE.
  }
  pose proof (Int.eq_spec LHS Int.zero) as ESL.
  pose proof (Int.eq_spec RHS Int.zero) as ESR.
  destruct (Int.eq LHS Int.zero); destruct (Int.eq RHS Int.zero); auto.
  - exfalso. apply ESR. apply Hequiv. exact ESL.
  - exfalso. apply ESL. apply Hequiv. exact ESR.
Qed.

Lemma sign_ext_eq_zero_iff_zero_ext_eq_zero:
  forall sz x,
    Int.eq (Int.sign_ext sz x) Int.zero =
    Int.eq (Int.zero_ext sz x) Int.zero.
Proof.
  intros sz x.
  pose proof (Int.eq_spec (Int.sign_ext sz x) Int.zero) as ES.
  pose proof (Int.eq_spec (Int.zero_ext sz x) Int.zero) as EZ.
  assert (Hequiv: Int.sign_ext sz x = Int.zero <-> Int.zero_ext sz x = Int.zero).
  { split.
    - intros H. apply Int.same_bits_eq. intros n Hn. rewrite Int.bits_zero.
      rewrite Int.bits_zero_ext by lia.
      destruct (zlt n sz) as [Hin|Hout]; auto.
      assert (HE: Int.testbit (Int.sign_ext sz x) n = Int.testbit Int.zero n)
        by (rewrite H; reflexivity).
      rewrite Int.bits_zero in HE.
      rewrite Int.bits_sign_ext in HE by lia.
      rewrite zlt_true in HE by lia. exact HE.
    - intros H. apply Int.same_bits_eq. intros n Hn. rewrite Int.bits_zero.
      rewrite Int.bits_sign_ext by lia.
      destruct (zlt n sz) as [Hin|Hout].
      + assert (HE: Int.testbit (Int.zero_ext sz x) n = Int.testbit Int.zero n)
          by (rewrite H; reflexivity).
        rewrite Int.bits_zero in HE.
        rewrite Int.bits_zero_ext in HE by lia.
        rewrite zlt_true in HE by lia. exact HE.
      + (* n >= sz, position is sz - 1; either out of range (returns false)
           or zero_ext gives same answer as testbit at sz-1 *)

        destruct (zlt 0 sz) as [Hpos|Hnonpos]; cycle 1.
        { (* sz <= 0: testbit at sz-1 < 0 is false *)
          unfold Int.testbit. apply Z.testbit_neg_r. lia. }
        (* 0 < sz, n >= sz, n < zwordsize ⇒ sz - 1 in [0, zwordsize) *)
        assert (Hszn : sz - 1 < Int.zwordsize) by lia.
        assert (HE: Int.testbit (Int.zero_ext sz x) (sz - 1) = Int.testbit Int.zero (sz - 1))
          by (rewrite H; reflexivity).
        rewrite Int.bits_zero in HE.
        rewrite Int.bits_zero_ext in HE by lia.
        rewrite zlt_true in HE by lia. exact HE.
  }
  destruct (Int.eq (Int.sign_ext sz x) Int.zero); destruct (Int.eq (Int.zero_ext sz x) Int.zero); auto.
  - exfalso. apply EZ. apply Hequiv. exact ES.
  - exfalso. apply ES. apply Hequiv. exact EZ.
Qed.

Lemma sign_ext_shru_and_bitfield_mask_zero:
  forall lsb sz i,
    Int.ltu lsb Int.iwordsize = true ->
    Int.eq (Int.sign_ext (Int.unsigned sz) (Int.shru i lsb)) Int.zero =
    Int.eq (Int.and i (bitfield_mask lsb sz)) Int.zero.
Proof.
  intros lsb sz i Hlsb.
  rewrite sign_ext_eq_zero_iff_zero_ext_eq_zero.
  apply zero_ext_shru_and_bitfield_mask_zero. exact Hlsb.
Qed.

Definition insf lsb sz prev fld :=
  let mask := Vint (bitfield_mask lsb sz) in
  if is_bitfield lsb sz
  then
    Val.or (Val.and prev (Val.notint mask))
           (Val.and (Val.shl fld (Vint lsb)) mask)
  else Vundef.

Definition clearf lsb sz prev :=
  let mask := Vint (bitfield_mask lsb sz) in
  if is_bitfield lsb sz
  then Val.and prev (Val.notint mask)
  else Vundef.

Definition zbitfield_mask zstop zstart :=
  Z.shiftl (Z.ones (Z.succ (zstop - zstart))) zstart.

bswap32 on values, lifted from Int.bswap32. Used as the semantics of Obswap32 and Prev so the lowering correctness is by construction; identical to the OR-of-shifts chain produced by the canonical C byte-swap idiom.

Definition val_bswap32 (v : val) : val :=
  match v with
  | Vint n => Vint (Int.bswap32 n)
  | _ => Vundef
  end.

Definition fast_isfinitef1 (x : val) :=
  let mask := Vint (Int.repr ExtFloats.isfinitef1_mask) in
  Val.cmp_bool Cne (Val.and (Val.bits_of_single x) mask) mask.

Lemma fast_isfinitef1_correct: forall x, fast_isfinitef1 x = Val.isfinitef x.
Proof.
  destruct x; trivial.
  cbn. f_equal. apply ExtFloats.fast_isfinitef1_correct.
Qed.

Definition fast_isfinitef2 (x : val) :=
  let mask := Vint (Int.repr ExtFloats.isfinitef2_mask) in
  Val.cmp_bool Cne (Val.zero_ext 8 (Val.shru (Val.bits_of_single x) (Vint (Int.repr 23)))) mask.

Lemma fast_isfinitef2_correct: forall x, fast_isfinitef2 x = Val.isfinitef x.
Proof.
  destruct x; trivial.
  unfold fast_isfinitef2. simpl.
  change (Int.ltu (Int.repr 23) Int.iwordsize) with true. simpl.
  f_equal. apply ExtFloats.fast_isfinitef2_correct.
Qed.

Definition fast_isfinite (x : val) :=
  let mask := Vint (Int.repr ExtFloats.isfinite_mask) in
  Val.cmp_bool Cne (Val.zero_ext 11 (Val.shru (Val.hiword (Val.bits_of_float x)) (Vint (Int.repr 20)))) mask.

Lemma fast_isfinite_correct: forall x, fast_isfinite x = Val.isfinite x.
Proof.
  destruct x; trivial.
  unfold fast_isfinite. simpl.
  change (Int.ltu (Int.repr 20) Int.iwordsize) with true. simpl.
  f_equal. apply ExtFloats.fast_isfinite_correct.
Qed.