\import Algebra.Domain
\import Algebra.Domain.Euclidean
\import Algebra.Monoid
\import Algebra.Monoid.Prime
\import Algebra.Ordered
\import Arith.Int
\import Arith.Nat
\import Data.Bool
\import Logic
\import Meta
\import Order.LinearOrder
\import Order.PartialOrder
\import Order.StrictOrder
\import Paths
\import Paths.Meta
\import Set
\open Nat(mod,-)
\open Monoid \hiding (equals)

-- A few characterizations of primality for natural numbers

-- Every irreducible element of the semiring of natural numbers is prime.
--
-- Cancelability is inherited directly from `p`. The primality clause is obtained by
-- a detour: `p` is repackaged as an irreducible element `pos p` of `IntDomain`,
-- `irr-isPrime` is applied there, and the divisibility statements are moved back and
-- forth with `ldiv_nat_int` / `ldiv_int_nat`.
\lemma nat_irr-isPrime (p : Irr {NatSemiring}) : Prime p p.notInv \cowith
  | isCancelable-left => p.isCancelable-left
  | isPrime {x} {y} p|x*y =>
    -- `irr` : `pos p` is irreducible. Non-invertibility and the factorization property are
    -- reduced to the corresponding facts about `p` (the latter via `iabs` of both factors);
    -- the last argument derives cancelability from `pos p` being non-zero.
    \let irr => \new Irr (pos p) (\lam p-inv => p.notInv (inv_int_nat p-inv)) (\lam {x} {y} p=x*y => \case p.isIrr {iabs x} {iabs y} (pmap iabs p=x*y *> iabs_*) \with {
      | byLeft x-inv => byLeft (inv_abs x-inv)
      | byRight y-inv => byRight (inv_abs y-inv)
    }) (\lam {_} {_} => Domain.nonZero-cancel-left (\lam p=0 => p.notIdemp (transportInv (\lam e => e = e * e) (pmap iabs p=0) idp)))
    -- Use primality of `pos p` on the lifted hypothesis and translate each disjunct back.
    \in \case Prime.isPrime {irr-isPrime {IntDomain} irr} {x} {y} (ldiv_nat_int p|x*y) \with {
          | byLeft p|x => byLeft (ldiv_int_nat p|x)
          | byRight p|y => byRight (ldiv_int_nat p|y)
        }

-- Characterization of primes by their divisors:
-- `n` is prime iff `n /= 1` and every divisor `k` of `n` is either `n` itself or `1`.
-- Both sides are propositions, so the equality is established with `propExt` from the
-- two implications `dir` and `conv` below.
\lemma prime-div {n : Nat} : Prime n = (\Sigma (n /= 1) (\Pi {k : Nat} -> LDiv k n -> (k = n) || (k = 1)))
  => propExt (\lam (p : Prime n) => (\lam n=1 => p.notInv (rewrite n=1 Inv.ide-isInv), dir p)) (\lam t => conv t.1 t.2)
  \where {
    -- Forward direction (only irreducibility is needed): a divisor `k` of an irreducible `n`
    -- is `n` or `1`. Irreducibility is applied to the factorization `n = k * k|n.inv`.
    \lemma dir {n : Nat} (p : Irr n) {k : Nat} (k|n : LDiv k n) : (k = n) || (k = 1)
      => \case p.isIrr (inv k|n.inv-right) \with {
        -- `k` is invertible, hence `k = 1`.
        | byLeft (s : Inv k) => byRight (natUnit s.inv-left)
        -- The cofactor is invertible, hence equal to `1`, so `k = k * 1 = n`.
        | byRight (s : Inv k|n.inv) => byLeft (inv (pmap (k *) (natUnit s.inv-left)) *> k|n.inv-right)
      }

    -- Converse direction: if `n /= 1` and every divisor of `n` is `n` or `1`, then `n` is prime.
    -- An `Irr n` structure is built and then upgraded to `Prime n` with `nat_irr-isPrime`.
    \lemma conv {n : Nat} (n/=1 : n /= 1) (f : \Pi {k : Nat} -> LDiv k n -> (k = n) || (k = 1)) : Prime n
      -- First exclude `n = 0`: if `n = 0` then `2` divides `n` (with cofactor `0`), so `f` would give
      -- `n = 2` (contradicting `n = 0`) or `2 = 1` (absurd).
      => \have n/=0 : n /= 0 => \lam n=0 => \case f {2} (\new LDiv { | inv => 0 | inv-right => inv n=0 }) \with {
           | byLeft n=2 => \case n=2 *> n=0
           | byRight ()
         }
         \in nat_irr-isPrime (\new Irr n {
           -- An invertible natural number equals `1`, which is excluded by `n/=1`.
           | notInv (n-inv : Inv) => n/=1 (natUnit n-inv.inv-left)
           -- Given `n = x * y`, the factor `x` divides `n`, so `f` applies to it.
           | isIrr {x} {y} n=x*y => \case f {x} (\new LDiv { | inv => y | inv-right => inv n=x*y }) \with {
             -- `x = n`: cancelling `n` in `n = n * y` shows that `y` is the unit, hence invertible.
             | byLeft x=n => byRight (rewriteI (NatSemiring.cancel_*-left n/=0 (n=x*y *> pmap (`* _) x=n)) Inv.ide-isInv)
             -- `x = 1`: `x` is the unit, hence invertible.
             | byRight x=1 => byLeft (rewrite x=1 Inv.ide-isInv)
           }
           | isCancelable-left => NatSemiring.cancel_*-left n/=0
         })
  }

-- Characterization of primes by their smaller divisors:
-- `n` is prime iff `n > 1` and every divisor `k < n` of `n` equals `1`.
-- Proved with `propExt`; the two implications are `dir` and `conv` below.
\lemma prime-less {n : Nat} : Prime n = (\Sigma (n > 1) (\Pi {k : Nat} -> k < n -> LDiv k n -> k = 1))
  => propExt (\lam p => (\case \elim n, p \with {
    -- `0` is not prime: left cancelability applied to `0 * 0 = 0 * 1` is contradictory.
    | 0, p : Prime 0 => \case p.isCancelable-left {0} {1} idp
    -- `1` is not prime: it is invertible.
    | 1, p : Prime 1 => absurd (p.notInv Inv.ide-isInv)
    -- Any number of the form `suc (suc n)` is greater than `1`.
    | suc (suc n), _ => NatSemiring.suc<suc NatSemiring.zero<suc
  }, dir p)) (\lam t => conv t.1 t.2)
  \where {
    -- Forward direction: a divisor `k` of an irreducible `n` with `k < n` equals `1`.
    -- Follows from `prime-div.dir`; the alternative `k = n` contradicts `k < n`.
    \lemma dir {n : Nat} (p : Irr n) {k : Nat} (k<n : k < n) (k|n : LDiv k n) : k = 1
      => \case prime-div.dir p k|n \with {
           | byLeft k=n => absurd (<-irreflexive (transportInv (k <) k=n k<n))
           | byRight k=1 => k=1
         }

    -- Converse direction, reduced to `prime-div.conv`.
    -- `n /= 1` follows from `n > 1`; a divisor `k` of `n` is compared with `n` by trichotomy.
    \lemma conv {n : Nat} (n>1 : n > 1) (f : \Pi {k : Nat} -> k < n -> LDiv k n -> k = 1) : Prime n
      => prime-div.conv (\lam n=1 => <-irreflexive (transportInv (`< n) n=1 n>1)) (\lam {k} k|n => \case LinearOrder.trichotomy k n \with {
        | less k<n => byRight (f k<n k|n)
        | equals k=n => byLeft k=n
        -- A divisor of a non-zero number cannot exceed it (`ldiv_<=`); `n /= 0` since `n > 1`.
        | greater k>n => absurd (ldiv_<= (\lam n=0 => \case transport (1 <) n=0 n>1) k|n k>n)
      })
  }

-- A primality test

-- Boolean primality test by trial division.
-- `0` and `1` are rejected, `2` is accepted; any other `n` is accepted when `n mod 2` is
-- non-zero and `rec n n 3` succeeds, i.e. the odd candidates starting from `3` are tried.
-- Soundness and completeness with respect to `Prime` are proved in `isPrime=>prime` and
-- `prime=>isPrime`.
\func isPrime (n : Nat) : Bool
  | 0 => false
  | 1 => false
  | 2 => true
  | n => isPositive (n mod 2) and rec n n 3
  \where {
    -- `true` exactly when the argument is not `0`.
    \func isPositive (n : Nat) : Bool
      | 0 => false
      | _ => true

    -- Boolean strict order on `Nat`, read off from the sign of the difference `n - m`:
    -- `true` when the difference is of the form `neg (suc _)`, `false` when it is `pos _`.
    \func \infix 4 < (n m : Nat) : Bool => \case n - m \with {
      | pos _ => false
      | neg (suc _) => true
    }

    -- Trial-division loop for `n`.
    --   `c` : remaining number of steps (the recursion is structural on `c`);
    --         when it reaches `0` the result is `true`.
    --   `j` : current candidate divisor, increased by `2` at every step.
    -- A step succeeds immediately if `n < j * j`; otherwise `n mod j` must be non-zero
    -- and the remaining steps must succeed for the candidate `j + 2`.
    \func rec (n c j : Nat) : Bool \elim c
      | 0 => true
      | suc c => n < j * j or isPositive (n mod j) and rec n c (suc (suc j))
  }

-- Bring the local definitions of `isPrime` into scope for the proofs below, except its
-- Boolean `<`, which is hidden and referred to as `isPrime.<` where needed.
\open isPrime \hiding (<)

-- Soundness of the primality test: if `isPrime n` evaluates to `true`, then `n` is prime.
-- Only the patterns `2` and `suc (suc (suc _))` are listed; no clauses are given for `0` and `1`
-- (presumably because `isPrime` returns `false` there, so `p` is impossible -- TODO confirm).
\lemma isPrime=>prime {n : Nat} (p : isPrime n = true) : Prime n \elim n, p
  -- `n = 2`: use `prime-less.conv`; the only candidates `k < 2` are `0` (which does not divide `2`) and `1`.
  | 2, _ => prime-less.conv (suc<suc zero<suc) (\lam {k} k<2 k|2 => \case \elim k, k<2, k|2 \with {
    | 0, _, q : LDiv => \case *-comm *> q.inv-right
    | 1, _, _ => idp
    | suc (suc _), suc<suc (suc<suc ()), _
  })
  -- `n >= 3`: split `p` into its two conjuncts. The first shows that `2` does not divide `n`,
  -- the second (`rec n n 3 = true`) rules out odd divisors through `rec-lem`, started at `j = 3`
  -- with the bound `n < n + 3`.
  | suc (suc (suc _)) \as n, p => prime-lem (suc<suc zero<suc) (mod_div-lem (and.toSigma p).1) (rec-lem {n} {n} {3} (and.toSigma p).2 () (<_+-left n {0} {3} zero<suc))
  \where {
    \open NatSemiring

    -- If `n mod k` is non-zero (according to `isPositive`), then `k` does not divide `n`.
    \lemma mod_div-lem {k n : Nat} (p : isPositive (n mod k) = true) : Not (LDiv k n)
      => \lam k|n => \case rewrite (div_mod k|n) in p

    -- Oddness as a proposition, defined by recursion in steps of two:
    -- `0` is not odd, `1` is odd, and `suc (suc n)` is odd iff `n` is.
    \func isOdd (n : Nat) : \Prop \elim n
      | 0 => Empty
      | 1 => \Sigma
      | suc (suc n) => isOdd n

    -- Every natural number is bounded by its square.
    \lemma square_<=-lem {k : Nat} : k <= k * k \elim k
      | 0 => zero<=_
      | suc k => <=_* <=-refl (suc<=suc zero<=_)

    -- Two consecutive numbers cannot both be odd.
    \lemma odd_suc-lem {j : Nat} (j-odd : isOdd j) (sj-odd : isOdd (suc j)) : Empty \elim j
      | 0 => j-odd
      | 1 => sj-odd
      | suc (suc j) => odd_suc-lem j-odd sj-odd

    -- Distinct odd numbers `j <= k` differ by at least two.
    -- The difference `t` with `j + t = k` is analysed: `0` contradicts `j /= k`,
    -- `1` contradicts oddness of both numbers, and `t >= 2` gives the claim.
    \lemma odd-lem {j k : Nat} (j-odd : isOdd j) (k-odd : isOdd k) (j<=k : j <= k) (j/=k : j /= k) : suc (suc j) <= k
      => \case k -' j \as t, <=_exists j<=k : j + t = k \with {
           | 0, j=k => absurd (j/=k j=k)
           | 1, p => absurd (odd_suc-lem j-odd (transportInv isOdd p k-odd))
           | suc (suc t), p => transport (suc (suc j) <=) p (suc<=suc (suc<=suc (LinearlyOrderedSemiring.<=_+ <=-refl zero<=_)))
         }

    -- Key invariant of the loop `rec`: if `rec n c j = true` for an odd `j` and `n < c + j`,
    -- then no odd `k` with `j <= k` and `k * k <= n` divides `n`. Proved by induction on `c`.
    \lemma rec-lem {n c j : Nat} (p : rec n c j = true) (j-odd : isOdd j) (s : n < c + j) {k : Nat} (k-odd : isOdd k) (j<=k : j <= k) (k*k<=n : k * k <= n) (k|n : LDiv k n) : Empty \elim c
      -- No steps left: `n < j <= k <= k * k <= n` is contradictory.
      | 0 => <-irreflexive (s <∘l j<=k <∘l square_<=-lem <∘l k*k<=n)
      | suc c => \case or.toOr p \with {
        -- The loop stopped because the Boolean test `n < j * j` succeeded. Inspecting `n - (j * j)`
        -- turns it into a proof of `n < j * j`, which contradicts `j * j <= k * k <= n`.
        | byLeft t => mcases {_} {arg addPath} t _ \with {
          | pos _, _, t' => \case t'
          | neg (suc n1), q, _ => <-irreflexive (triLess q <∘l <=_* j<=k j<=k <∘l k*k<=n)
        }
        -- The loop continued: `n mod j` is non-zero and the rest of the loop succeeded.
        | byRight t => \case decideEq j k \with {
          -- `k` is the current candidate, which was shown not to divide `n`.
          | yes j=k => mod_div-lem (and.toSigma t).1 (transportInv (LDiv __ n) j=k k|n)
          -- Otherwise `j + 2 <= k` by `odd-lem`, so the induction hypothesis applies to the next candidate.
          | no j/=k => rec-lem (and.toSigma t).2 j-odd (s <∘l <=-less id<suc) k-odd (odd-lem j-odd k-odd j<=k j/=k) k*k<=n k|n
        }
      }

    -- Every natural number is odd or divisible by `2`.
    \lemma oddOrEven (n : Nat) : isOdd n || LDiv 2 n
      | 0 => byRight (\new LDiv { | inv => 0 | inv-right => idp })
      | 1 => byLeft ()
      | suc (suc n) => \case oddOrEven n \with {
        | byLeft r => byLeft r
        | byRight (d : LDiv 2 n) => byRight (\new LDiv {
          | inv => suc d.inv
          | inv-right => pmap (\lam t => suc (suc t)) d.inv-right
        })
      }

    -- Primality criterion used for `n >= 3`: `n > 1`, `n` is not divisible by `2`, and no odd
    -- `k >= 3` with `k * k <= n` divides `n`. Reduced to `prime-lem2` by a case split on the
    -- candidate `k` and its parity.
    \lemma prime-lem {n : Nat} (n>1 : n > 1) (div2 : Not (LDiv 2 n)) (divOdd : \Pi {k : Nat} -> isOdd k -> 3 <= k -> k * k <= n -> Not (LDiv k n)) : Prime n
      => prime-lem2 n>1 (\lam {k} => \case \elim k, oddOrEven k \with {
        -- `0` and `2` are not odd; `1` satisfies the conclusion trivially.
        | 0, byLeft ()
        | 1, byLeft odd => \lam _ _ => idp
        | 2, byLeft ()
        -- An odd `k >= 3` cannot divide `n` by `divOdd`.
        | suc (suc (suc _)) \as k, byLeft odd => \lam p k|n => absurd (divOdd {k} odd (suc<=suc (suc<=suc (suc<=suc zero<=_))) p k|n)
        -- An even `k` dividing `n` would make `2` divide `n`.
        | k, byRight even => \lam _ k|n => absurd (div2 (LDiv.trans even k|n))
      })

    -- Square-root bound: if `n > 1` and every divisor `j` of `n` with `j * j <= n` equals `1`,
    -- then `n` is prime. For a divisor `k < n` with cofactor `i = k|n.inv`, either `k <= i`
    -- (then `k * k <= n` and `f` applies to `k`) or `i <= k` (then `f` applies to `i`, giving
    -- `i = 1` and hence `k = n`, contradicting `k < n`).
    \lemma prime-lem2 {n : Nat} (n>1 : n > 1) (f : \Pi {j : Nat} -> j * j <= n -> LDiv j n -> j = 1) : Prime n
      => prime-less.conv n>1 (\lam {k} k<n (k|n : LDiv k n) => \case totality k k|n.inv \with {
        | byLeft k<=i => f (transport (k * k <=) k|n.inv-right (<=_* <=-refl k<=i)) k|n
        | byRight i<=k =>
          \have i=1 => f (transport (k|n.inv * k|n.inv <=) k|n.inv-right (<=_* i<=k <=-refl)) (\new LDiv k|n.inv n k (*-comm *> k|n.inv-right))
          \in absurd (<-irreflexive (transport (`< n) (pmap (_ *) (inv i=1) *> k|n.inv-right) k<n))
      })
  }

-- Completeness of the primality test: if `n` is prime, then `isPrime n` evaluates to `true`.
\lemma prime=>isPrime {n : Nat} (p : Prime n) : isPrime n = true \elim n
  -- `0` is not prime: left cancelability applied to `0 * 0 = 0 * 1` is contradictory.
  | 0 => \case p.isCancelable-left {0} {1} idp
  -- `1` is not prime: it is invertible.
  | 1 => absurd (p.notInv Inv.ide-isInv)
  | 2 => idp
  -- `n >= 3`: establish both conjuncts of the test. `2 < n` and `2 /= 1` give that `n mod 2` is
  -- non-zero; `rec-lem` shows that the loop started at `3` succeeds.
  | suc (suc (suc _)) \as n => and.fromSigma (mod-lem p (suc<suc (suc<suc zero<suc)) (\case __), rec-lem {n} {n} p (suc<suc zero<suc))
  \where {
    \open NatSemiring

    -- For a prime `n` and `k < n` with `k /= 1`, the remainder `n mod k` is non-zero:
    -- a zero remainder would make `k` a divisor of `n`, forcing `k = 1` by `prime-less.dir`.
    \lemma mod-lem {k n : Nat} (p : Prime n) (k<n : k < n) (k/=1 : k /= 1) : isPositive (n mod k) = true
      => mcases {isPositive} {arg addPath} \with {
        | 0, q => absurd (k/=1 (prime-less.dir p k<n (mod_div q)))
        | suc _, _ => idp
      }

    -- The Boolean comparison `isPrime.<` returns `true` whenever `n < m` holds.
    -- Proved by inspecting `n - m`: a zero or positive difference contradicts `n < m`.
    \lemma <-lem {n m : Nat} (n<m : n < m) : n isPrime.< m = true => unfold (isPrime.<) (mcases {_} {arg addPath} \with {
      | 0, p => absurd (<-irreflexive (transportInv (n <) (triEquals p) n<m))
      | pos (suc k), p => absurd (<-irreflexive (<-transitive n<m (triGreater p)))
      | neg (suc k), _ => idp
    })

    -- For `j > 1`, the successor of `j` is strictly below `j * j`.
    \lemma square_<-lem {j : Nat} (j>1 : j > 1) : suc j < j * j \elim j, j>1
      | 1, suc<suc ()
      | suc (suc j), _ => <_+-left (j + 3) zero<suc

    -- For a prime `n` the loop `rec n c j` succeeds for every step count `c` and every
    -- candidate `j > 1`. Proved by induction on `c`.
    \lemma rec-lem {n c j : Nat} (p : Prime n) (j>1 : j > 1) : rec n c j = true \elim c
      | 0 => idp
      | suc c => or.fromOr (\case totality (suc j) n \with {
        -- `j + 1 <= n`: then `j < n` and `j /= 1`, so `n mod j` is non-zero by `mod-lem`;
        -- the remaining steps succeed by induction, since `j + 2 > 1`.
        | byLeft j+1<=n => byRight (and.fromSigma (mod-lem p (id<suc <∘l j+1<=n) (\case rewrite __ in j>1 \with { | suc<suc () }), rec-lem p (<-transitive j>1 (<-transitive id<suc id<suc))))
        -- `n <= j + 1 < j * j`: the early-exit test `n < j * j` succeeds.
        | byRight n<=j+1 => byLeft (<-lem (n<=j+1 <∘r square_<-lem j>1))
      })
  }

-- Primality of a natural number is decidable.
-- The Boolean test `isPrime n` is evaluated together with a proof `p` that it equals the
-- result `b`: `true` yields a proof of `Prime n` by soundness (`isPrime=>prime`), while
-- `false` refutes `Prime n` by completeness (`prime=>isPrime`), since it would give `true = false`.
\func prime-isDec {n : Nat} : Dec (Prime n)
  => \case isPrime n \as b, idp : isPrime n = b \with {
       | true, p => yes (isPrime=>prime p)
       | false, p => no (\lam c => \case inv (prime=>isPrime c) *> p)
     }