Module MultiFixpoint


Require Import Maps.
Require Lattice.
Require Import Coqlib.
Require Import Iteration.
Require Import Classical.
Require Kildall.
Require List.
  
Module Solver (NS: Kildall.NODE_SET) (L: Lattice.SEMILATTICE_WITH_WIDENING).
  Module NodeMap := PMap.
  Definition node := positive.

  Definition nodemap_ge m1 m2 := forall n, L.ge m1!!n m2!!n.

  Lemma nodemap_ge_refl:
    forall m, nodemap_ge m m.
  Proof.
    unfold nodemap_ge; intros; apply L.ge_refl; apply L.eq_refl.
  Qed.

  Lemma nodemap_ge_trans:
    forall a b c, nodemap_ge a b -> nodemap_ge b c -> nodemap_ge a c.
  Proof.
    unfold nodemap_ge; intros.
    eapply L.ge_trans; eauto.
  Qed.
  
  Record state :=
    mkstate
      { state_map : NodeMap.t L.t;
        state_unstable : NS.t }.

  Section WIDEN_SELECT.
    Variable use_widening: node -> bool.
    Definition lub_or_widen n := if use_widening n then L.widen else L.lub.

    Lemma ge_lub_or_widen_left:
      forall n x y, L.ge (lub_or_widen n x y) x.
    Proof.
      unfold lub_or_widen. intros.
      destruct use_widening.
      - apply L.ge_widen_left.
      - apply L.ge_lub_left.
    Qed.

    Lemma ge_lub_or_widen_right:
      forall n x y, L.ge (lub_or_widen n x y) y.
    Proof.
      unfold lub_or_widen. intros.
      destruct use_widening.
      - apply L.ge_widen_right.
      - apply L.ge_lub_right.
    Qed.
    
  (** [update_once s n update] merges [update] into the value stored for [n]
      using [lub_or_widen n].  If [L.beq] does not report the merged value
      as equal to the previous one, the entry for [n] is overwritten and [n]
      is added to the unstable set.  Otherwise [s] is returned as is.
      NOTE(review): [map'] is only consumed in the [changed] branch, so its
      own [if changed] (and the [else map] arm) is redundant.  The shape is
      kept because the proofs below unfold this definition and case on
      [L.beq].  Any simplification must be made together with them. *)
  Definition update_once (s : state) (n : node) (update : L.t): state :=
    let map := s.(state_map) in
    let previous := NodeMap.get n map in
    let updated := lub_or_widen n previous update in
    (* [changed = true] when [L.beq] does not identify old and merged values. *)
    let changed := negb (L.beq previous updated) in
    let map' := if changed
                then NodeMap.set n updated map
                else map in
    if changed
    then {| state_map := map';
            state_unstable := NS.add n s.(state_unstable) |}
    else s.

  (** Nodes already in the unstable set stay there after [update_once]:
      the set is either returned unchanged or extended with [NS.add]. *)
  Lemma update_once_unstable_remains:
    forall s n n' update
      (IN : NS.In n s.(state_unstable)),
      NS.In n (update_once s n' update).(state_unstable).
  Proof.
    intros.
    unfold update_once.
    pose proof (L.beq_correct (state_map s) !! n' (lub_or_widen n' (state_map s) !! n' update)) as BEQ.
    (* Case on the equality test performed inside [update_once]. *)
    destruct L.beq; cbn in *.
    (* No change: the state, hence its unstable set, is [s] itself. *)
    assumption.
    (* Change: [n'] was added; [n] is still a member by [NS.add_spec]. *)
    pose proof (NS.add_spec n' n (state_unstable s)) as ADD.
    apply ADD.
    auto.
  Qed.

  Definition increase_b n (a b : L.t) :=
    negb (L.beq a (lub_or_widen n a b)).

  Definition increase n a b := increase_b n a b = true.

  (** If [n'] is not in the unstable set after [update_once s n update],
      then its map entry is the same as in [s]. *)
  Lemma update_once_map_stable_remains:
    forall s n update n'
      (NOT_IN : ~NS.In n' (update_once s n update).(state_unstable)),
      (update_once s n update).(state_map) !! n' =
        s.(state_map) !! n'.
  Proof.
    intros.
    unfold update_once in *.
    (* Case on the equality test performed inside [update_once]. *)
    destruct L.beq; cbn in *.
    (* No change: the map is [s]'s map. *)
    reflexivity.
    destruct (peq n n').
    { (* [n' = n] would be in [NS.add n _], contradicting [NOT_IN]. *)
      subst n'.
      rewrite NS.add_spec in NOT_IN.
      exfalso. auto.
    }
    (* [n <> n']: [NodeMap.set n] leaves the entry of [n'] alone. *)
    apply NodeMap.gso.
    congruence.
  Qed.
        
  Definition update_multiple (s0 : state)
             (updates : list (node * L.t)%type): state :=
    List.fold_left (fun s (update : (node * L.t)%type) =>
                      let (node, v) := update in
                      update_once s node v) updates s0.
  
  (** Unstable nodes stay unstable through a whole [update_multiple];
      lifted from [update_once_unstable_remains] by induction. *)
  Lemma update_multiple_unstable_remains:
    forall updates s n
           (IN : NS.In n s.(state_unstable)),
      NS.In n (update_multiple s updates).(state_unstable).
  Proof.
    induction updates.
    { cbn. auto. }
    simpl update_multiple.
    destruct a as [n' update].
    intros.
    (* Head update first, then the induction hypothesis on the tail. *)
    apply IHupdates.
    apply update_once_unstable_remains.
    assumption.
  Qed.
  
  (** If [n'] is not unstable after [update_multiple s updates], none of the
      individual updates touched its map entry. *)
  Lemma update_multiple_map_stable_remains:
    forall updates s n'
      (NOT_IN : ~NS.In n' (update_multiple s updates).(state_unstable)),
      (update_multiple s updates).(state_map) !! n' =
        s.(state_map) !! n'.
  Proof.
    induction updates; intros.
    reflexivity.
    destruct a as [n update].
    simpl update_multiple in *.
    (* Tail first (induction hypothesis), then the head [update_once]. *)
    rewrite IHupdates by assumption.
    apply update_once_map_stable_remains.
    (* If [n'] were unstable after the head update, it would remain so
       through the tail, contradicting [NOT_IN]. *)
    intro IN'.
    apply NOT_IN.
    apply update_multiple_unstable_remains.
    assumption.
  Qed.
  
  Definition initialize := update_multiple
                             {| state_map := NodeMap.init L.bot ;
                                state_unstable := NS.empty |}.

  Lemma Nge_widen:
    forall n a b c
           (NGE : ~ L.ge (lub_or_widen n a b) c),
      ~ L.ge a c.
  Proof.
    intros.
    intro GE.
    apply NGE.
    apply L.ge_trans with (y := a).
    - apply ge_lub_or_widen_left.
    - assumption.
  Qed.

  Lemma eq_widen_ge:
    forall n a b
           (EQ : L.eq a (lub_or_widen n a b)), (L.ge a b).
  Proof.
    intros.
    apply L.ge_trans with (y := lub_or_widen n a b).
    - apply L.ge_refl. assumption.
    - apply ge_lub_or_widen_right.
  Qed.
  
  (** After [update_once s n update], the entry for [n] is above [update]. *)
  Lemma update_once_greater2:
    forall s n update,
      (L.ge ((state_map (update_once s n update)) !! n) update).
  Proof.
    intros.
    unfold update_once.
    pose proof (L.beq_correct (state_map s) !! n (lub_or_widen n (state_map s) !! n update)) as CORRECT.
    destruct L.beq; cbn in *.
    (* No change: the old value is [L.eq] to the merged one ([CORRECT]),
       so it is above [update] by [eq_widen_ge]. *)
    { apply (eq_widen_ge n). auto. }
    (* Change: the entry now holds the merged value itself. *)
    rewrite NodeMap.gss.
    apply ge_lub_or_widen_right.
  Qed.

  (** [update_once] never lowers any entry of the map. *)
  Lemma update_once_greater1:
    forall s n update n',
      (L.ge ((state_map (update_once s n update)) !! n')
            ((state_map s) !! n')).
  Proof.
    intros.
    unfold update_once.
    pose proof (L.beq_correct (state_map s) !! n (lub_or_widen n (state_map s) !! n update)) as CORRECT.
    destruct L.beq; cbn in *.
    (* No change: reflexivity of [L.ge]. *)
    { apply L.ge_refl.
      apply L.eq_refl.
    }
    destruct (peq n n') as [EQ | NEQ].
    (* The updated node: the merged value is above the old one. *)
    { subst n'.
      rewrite NodeMap.gss.
      apply ge_lub_or_widen_left.
    }
    (* Any other node: entry untouched. *)
    rewrite NodeMap.gso by congruence.
    apply L.ge_refl.
    apply L.eq_refl.
  Qed.

  (** [update_multiple] never lowers any entry of the map. *)
  Lemma update_multiple_greater1:
    forall updates s n',
      (L.ge ((state_map (update_multiple s updates)) !! n')
            ((state_map s) !! n')).
  Proof.
    induction updates; intros.
    (* Empty list: the state is unchanged. *)
    { cbn.
      apply L.ge_refl.
      apply L.eq_refl.
    }
    simpl update_multiple.
    destruct a as [node update].
    (* Chain: tail (induction hypothesis) above head update above [s]. *)
    eapply L.ge_trans.
    - apply IHupdates.
    - apply update_once_greater1.
  Qed.

  (** Every pair [(n, update)] in [updates] is absorbed: afterwards the entry
      for [n] is above [update]. *)
  Lemma update_multiple_greater2:
    forall updates s n update
           (IN : In (n, update) updates),
      (L.ge ((state_map (update_multiple s updates)) !! n)
            update).
  Proof.
    induction updates; intros.
    contradiction.
    destruct a as [n0 update0].
    destruct IN as [FIRST | OTHERS].
    (* The pair is the head: [update_once] absorbs it and the tail only
       raises entries further ([update_multiple_greater1]). *)
    { inv FIRST.
      simpl update_multiple.
      eapply L.ge_trans.
      - apply update_multiple_greater1.
      - apply update_once_greater2.
    }
    (* The pair is in the tail: induction hypothesis. *)
    simpl update_multiple.
    apply IHupdates.
    assumption.
  Qed.
      
  Section STEP.

    (** Transfer function: [f n v] lists the [(target, value)] pairs that
        [step] merges into the map when node [n] holds value [v]. *)
    Variable f : node -> L.t -> list (node * L.t)%type.
    (** [f] produces no update from [L.bot]; this is what makes the initial
        all-[L.bot] state satisfy [stability_invariant]. *)
    Hypothesis f_strict : forall n,
        (f n L.bot) = nil.

    (** Node [n] is stabilized in [s] if every update [(n', v')] that [f]
        produces from the current value at [n] is already below the entry
        of [n'], unless [n] itself is still in the unstable set.
        The binder names [IN_UPDATE] and [INCREASE] are relied upon by
        proofs that use plain [intros]; do not rename them. *)
    Definition stabilized_node (s : state) (n : node) :=
      forall (n' : node) (v' : L.t)
        (IN_UPDATE : In (n', v') (f n (NodeMap.get n s.(state_map))))
        (INCREASE : ~L.ge (NodeMap.get n' s.(state_map)) v'),
      NS.In n (s.(state_unstable)).
    
    (** [update_once] (at any node [nu]) preserves stabilization of [n]. *)
    Lemma update_once_stabilized_remains:
      forall s n nu update
        (STABLE : stabilized_node s n),
        stabilized_node (update_once s nu update) n.
    Proof.
      unfold stabilized_node, update_once.
      intros.
      pose proof (L.beq_correct (state_map s) !! nu (lub_or_widen nu (state_map s) !! nu update)) as EQZ.
      destruct L.beq; cbn in *.
      (* No change: the state is [s], use [STABLE]. *)
      { pose proof (EQZ eq_refl) as EQ.
        clear EQZ.
        eauto.
      }
      clear EQZ.
      (* Change: the unstable set is [NS.add nu _]. *)
      apply NS.add_spec.
      destruct (peq nu n) as [ EQ | nu_NEQ_n].
      (* [n] itself was updated, so it was just added. *)
      { subst nu.
        auto.
      }
      right.
      (* The value at [n] is untouched, so [f] yields the same updates. *)
      rewrite NodeMap.gso in IN_UPDATE by congruence.
      destruct (peq nu n') as [ nu_EQ_n' | nu_NEQ_n' ].
      (* Target [n'] was updated: not above [v'] after the merge implies
         not above [v'] before ([Nge_widen]), so [STABLE] applies. *)
      { subst nu.
        rewrite NodeMap.gss in INCREASE.
        apply (STABLE _ _ IN_UPDATE).
        eapply Nge_widen.
        eassumption.
      }
      (* Target [n'] untouched: [STABLE] applies directly. *)
      rewrite NodeMap.gso in INCREASE by congruence.
      eauto.
    Qed.

    (** [update_multiple] preserves stabilization of [n]; induction over the
        list using [update_once_stabilized_remains]. *)
    Lemma update_multiple_stabilized_remains:
      forall updates s n
        (STABLE : stabilized_node s n),
        stabilized_node (update_multiple s updates) n.
    Proof.
      induction updates; simpl update_multiple; intros.
      assumption.
      destruct a as [nu update].
      apply IHupdates.
      apply update_once_stabilized_remains.
      assumption.
    Qed.

    (** Applying all updates that [f] produces from node [n] stabilizes [n].
        Either [n] ends up unstable, or its entry was not modified.  In the
        latter case the updates recomputed from it are exactly those
        applied, and each is absorbed by [update_multiple_greater2]. *)
    Lemma stabilize_node:
      forall s n,
      stabilized_node
        (update_multiple s (f n (state_map s) !! n)) n.
    Proof.
      unfold stabilized_node.
      intros.
      
      destruct (classic (NS.In n (state_unstable (update_multiple s (f n (state_map s) !! n))))) as [IN | NOT_IN].
      assumption.
      
      (* [n] not unstable: its entry is unchanged, so [IN_UPDATE] refers to
         one of the applied updates. *)
      rewrite update_multiple_map_stable_remains in IN_UPDATE by assumption.
      exfalso.
      apply INCREASE.
      apply update_multiple_greater2.
      assumption.
    Qed.
    
    Definition stability_invariant (s : state) :=
      forall n, (stabilized_node s n).
  
    (** Any state built by [initialize] satisfies the stability invariant:
        a node not left unstable still holds [L.bot], from which [f]
        produces no update ([f_strict]). *)
    Lemma initialize_stable :
      forall l, (stability_invariant (initialize l)).
    Proof.
      unfold initialize, stability_invariant, stabilized_node in *.
      intros.
      destruct (classic (NS.In n
    (state_unstable
       (update_multiple
          {| state_map := NodeMap.init L.bot; state_unstable := NS.empty |} l)))) as [ IN | NOT_IN].
      assumption.
      (* Not unstable: entry unchanged from [NodeMap.init L.bot]. *)
      rewrite update_multiple_map_stable_remains in IN_UPDATE by assumption.
      cbn in IN_UPDATE.
      rewrite NodeMap.gi in IN_UPDATE.
      (* [f n L.bot = nil], so [IN_UPDATE] is impossible. *)
      rewrite f_strict in IN_UPDATE.
      contradiction.
    Qed.
    
    Definition step (s : state) : option state :=
      match NS.pick (state_unstable s) with
      | None => None
      | Some(node, unstable') =>
          Some (update_multiple {| state_map := s.(state_map);
                                  state_unstable := unstable' |}
                                (f node (NodeMap.get node s.(state_map))))
      end.

    (** [m] is stable when every update [(n', v')] that [f] produces from
        any node [n] is already below the entry of [n'] in [m].
        The binder name [IN_UPDATE] is relied upon by proofs that use plain
        [intros]; do not rename it. *)
    Definition is_stable (m : NodeMap.t L.t) :=
      forall (n n': node) (v' : L.t)
        (IN_UPDATE : In (n', v') (f n (NodeMap.get n m))),
        L.ge (NodeMap.get n' m) v'.

    Lemma step_end : forall s
        (INVARIANT : stability_invariant s)
        (END : (step s) = None),
        is_stable (s.(state_map)).
    Proof.
      unfold step.
      intro.
      pose proof (NS.pick_none (state_unstable s)) as NONE.
      destruct NS.pick as [[n update] | ].
      { intro. discriminate. }
      intros.
      unfold is_stable, stability_invariant, stabilized_node, increase, increase_b in *.
      intros.
      destruct (classic (L.ge (state_map s) !! n' v')) as [GE | NGE].
      assumption.
      pose proof (INVARIANT _ _ _ IN_UPDATE NGE) as INCR.
      pose proof (L.beq_correct (state_map s) !! n' (L.lub (state_map s) !! n' v')) as EQ_CORRECT.
      destruct L.beq; cycle 1.
      { exfalso.
        apply (NONE n eq_refl).
        auto.
      }
      apply L.ge_trans with (y := (L.lub (state_map s) !! n' v')).
      { apply L.ge_refl.
        auto.
      }
      apply L.ge_lub_right.
    Qed.

    (** [step] preserves [stability_invariant]. *)
    Lemma step_invariant:
      forall s s' (INVARIANT : (stability_invariant s))
                   (STEP : (step s) = (Some s')),
                   (stability_invariant s').
    Proof.
      unfold step, stability_invariant.
      intros.

      pose proof (NS.pick_some (state_unstable s)) as PICK_SOME.
      destruct (NS.pick (state_unstable s)).
      2: discriminate.
      destruct p as [picked_node after_picking].
      (* [PICKED]: a node is in the old unstable set iff it is
         [picked_node] or it is in [after_picking]. *)
      pose proof (PICK_SOME picked_node after_picking eq_refl) as PICKED.
      clear PICK_SOME.

      inv STEP.

      destruct (peq picked_node n) as [EQ | NEQ].
      (* The processed node: stabilized by [stabilize_node]. *)
      { subst picked_node.
        replace (state_map s)
          with (state_map {| state_map := state_map s; state_unstable := after_picking |}) by reflexivity.
        apply stabilize_node.
      }

      (* Any other node keeps its membership when [picked_node] is
         removed from the unstable set. *)
      assert (NS.In n (state_unstable s) <->
                NS.In n after_picking) as EQV.
      { rewrite (PICKED n).
        tauto.
      }

      (* Stabilization in [s] transfers to the record with the smaller
         unstable set, then survives [update_multiple]. *)
      apply update_multiple_stabilized_remains.
      unfold stabilized_node in *.
      cbn.
      intros.
      rewrite <- EQV.
      eapply INVARIANT; eassumption.
    Qed.

    Definition step_ret :=
      fun s =>
        match step s with
        | None => inl s.(state_map)
        | Some s' => inr s'
        end.
    
    Definition next_fixpoint :=
      PrimIter.iterate _ _ step_ret.

    (** Starting from a state satisfying the invariant, any map returned by
        [next_fixpoint] is stable. *)
    Lemma next_fixpoint_solves :
      forall s (INVARIANT : stability_invariant s) m
        (SOLVED : (next_fixpoint s) = Some m),
        is_stable m.
    Proof.
      unfold next_fixpoint.
      intros.
      (* Loop invariant [stability_invariant]; [SOLVED] discharges the
         [iterate s = Some m] premise. *)
      apply PrimIter.iterate_prop with (step := step_ret) (P := stability_invariant) (a := s) (2 := SOLVED).
      2: assumption.
      intros a INV.
      unfold step_ret.
      destruct (step a) as [s' | ] eqn:STEP.
      (* Continue: the invariant is preserved by [step_invariant]. *)
      {
        pose proof (step_invariant a s').
        auto.
      }
      (* Stop: the map is stable by [step_end]. *)
      apply step_end; auto.
    Qed.

    Lemma step_ge:
      forall s s' (STEP : (step s) = (Some s')),
        (nodemap_ge s'.(state_map) s.(state_map)).
    Proof.
      intros.
      unfold step in *.
      pose proof (NS.pick_some (state_unstable s)) as PICK_SOME.
      destruct (NS.pick (state_unstable s)).
      2: discriminate.
      destruct p as [picked_node after_picking].
      pose proof (PICK_SOME picked_node after_picking eq_refl) as PICKED.
      clear PICK_SOME.
      
      inv STEP.
      unfold nodemap_ge.

      change (state_map s) with
        (state_map {| state_map := state_map s; state_unstable := after_picking |}).
      
      apply update_multiple_greater1.
   Qed.
              
    (** Any map returned by [next_fixpoint s0] is above the map of [s0]. *)
    Lemma next_fixpoint_ge :
      forall s0 m
        (SOLVED : (next_fixpoint s0) = Some m),
        nodemap_ge m s0.(state_map).
    Proof.
      unfold next_fixpoint.
      intros until m. intro.
      (* Loop invariant: the current state's map is above [s0]'s map. *)
      apply PrimIter.iterate_prop with (step := step_ret)
        (P := fun s' => nodemap_ge s'.(state_map) s0.(state_map)) (a := s0)
        (2 := SOLVED).
      (* Initially true by reflexivity; [; fail] checks that [apply]
         closes the goal completely. *)
      2: apply nodemap_ge_refl; fail.
      intros s GE.
      unfold step_ret.
      destruct (step s) as [s' | ] eqn:STEP.
      (* Continue: [step_ge] plus transitivity. *)
      { eapply nodemap_ge_trans.
        apply (step_ge _ _ STEP).
        assumption.
      }
      (* Stop: the returned map is the current one. *)
      assumption.
    Qed.
      
    Variable initial : list (node * L.t)%type.

    Definition solution_opt := next_fixpoint (initialize initial).

    Lemma solution_stable : forall solution,
        solution_opt = Some solution -> is_stable solution.
    Proof.
      unfold solution_opt.
      intros.
      apply next_fixpoint_solves with (s := initialize initial).
      apply initialize_stable.
      assumption.
    Qed.

    (** Every seed pair [(node, v)] is absorbed by the solution. *)
    Lemma solution_includes_initial : forall solution node v
        (SOL : solution_opt = Some solution)
        (IN : In (node, v) initial),
        L.ge (solution !! node) v.
    Proof.
      unfold solution_opt.
      intros.
      (* The solution is above the initial map, which absorbed the seeds. *)
      pose proof (next_fixpoint_ge _ _ SOL) as GE.
      unfold nodemap_ge in *.
      eapply L.ge_trans.
      { apply GE. }
      unfold initialize.
      apply update_multiple_greater2.
      assumption.
    Qed.

    Section INVARIANT.
      (** A property of abstract values to be shown for every entry of the
          solution.  It must hold for [L.bot] and the seeds, and be
          preserved by the merge operator and by [f]. *)
      Variable P : L.t -> Prop.
      Hypothesis P_bot : P L.bot.
      Hypothesis P_widen : forall n x y, P x -> P y -> P (lub_or_widen n x y).
      Hypothesis P_f : forall n v n' v',
          P v ->
          In (n', v') (f n v) ->
          P v'.
      Hypothesis P_initial : forall n v,
          In (n, v) initial -> P v.

      Definition P_state s :=
        forall n, P (PMap.get n s.(state_map)).

      Definition update_once_P:
        forall s n update
               (BEFORE : P_state s)
               (UPDATE : P update),
          (P_state (update_once s n update)).
      Proof.
        intros.
        unfold P_state, update_once in *.
        intro n0.
        destruct L.beq; cbn.
        { auto. }
        destruct (peq n n0).
        { subst n0.
          rewrite NodeMap.gss.
          auto.
        }
        rewrite NodeMap.gso by congruence.
        auto.
      Qed.

      Definition update_multiple_P:
        forall updates s
               (BEFORE : P_state s)
               (UPDATE : forall n' v', In (n', v') updates -> P v'),
          (P_state (update_multiple s updates)).
      Proof.
        induction updates; intros.
        assumption.
        destruct a as (n', v').
        simpl update_multiple.
        apply IHupdates.
        {
          apply update_once_P.
          assumption.
          { eapply UPDATE.
            cbn.
            left.
            reflexivity.
          }
        }
        intros.
        eapply UPDATE.
        cbn.
        right.
        eassumption.
      Qed.

      (** [step] preserves [P_state]. *)
      Lemma step_P:
        forall s (BEFORE : P_state s) s'
          (STEP : (step s) = Some s'),
          (P_state s').
      Proof.
        intros.
        unfold step in *.
        destruct NS.pick. 2: discriminate.
        destruct p as [n v].
        inv STEP.
        apply update_multiple_P.
        assumption.
        intros.
        (* The merged values come from [f] applied to an entry satisfying
           [P] ([BEFORE]), hence satisfy [P] by [P_f]. *)
        eauto.
      Qed.

      (** Starting from a state satisfying [P_state], every entry of the map
          returned by [next_fixpoint] satisfies [P]. *)
      Lemma next_fixpoint_P :
      forall s (INVARIANT : P_state s) m'
        (SOLVED : (next_fixpoint s) = Some m') n,
        P (PMap.get n m').
      Proof.
        intros.
        unfold next_fixpoint in *.
        (* Loop invariant [P_state]; final property: all entries in [P]. *)
        apply (PrimIter.iterate_prop _ _ step_ret P_state
             (fun m => forall n, P (PMap.get n m))) with (a := s).
        2, 3: assumption.
        intros a Ha.
        unfold step_ret.
        (* Continue: [step_P].  Stop: the map of [a] itself. *)
        pose proof (step_P a) as STEP.
        destruct step; auto.
      Qed.

    (** The initial state satisfies [P_state]: [L.bot] entries by [P_bot],
        seeds by [P_initial]. *)
    Lemma initialize_P:
      P_state (initialize initial).
    Proof.
      unfold initialize.
      apply update_multiple_P.
      { unfold P_state.
        cbn.
        intro.
        rewrite NodeMap.gi.
        apply P_bot.
      }
      assumption.
    Qed.
      
    Lemma solution_P : forall solution,
        solution_opt = Some solution ->
        forall n, P (PMap.get n solution).
    Proof.
      unfold solution_opt.
      intros.
      apply next_fixpoint_P with (s := initialize initial).
      apply initialize_P.
      assumption.
    Qed.
    End INVARIANT.
  End STEP.
  End WIDEN_SELECT.
End Solver.