\import Algebra.Semiring
\import Data.Array
\import Data.Array.EPerm
\import Data.Fin (fsuc, nat_fin_=)
\import Data.Or
\import Equiv
\import Function
\import Function.Meta
\import Logic
\import Meta
\import Paths
\import Paths.Meta
\func pairs {A B C : \Type} (f : A -> B -> C) (l : Array A) (l' : Array B) : Array C \elim l
| nil => nil
| a :: l => map (f a) l' ++ pairs f l l'
\where {
\open EPerm
\func pairs_nil {A B C : \Type} {f : A -> B -> C} {l : Array A} : pairs f l nil = nil \elim l
| nil => idp
| a :: l => pairs_nil
\func pairs-flip {A B C : \Type} {f : A -> B -> C} {l : Array A} {l' : Array B} : EPerm (pairs f l l') (pairs (\lam b a => f a b) l' l) \elim l, l'
| nil, nil => eperm-nil
| nil, a :: l => pairs-flip
| a :: l, nil => pairs-flip
| a :: l, b :: l' => eperm-:: idp $ eperm-++-right pairs-flip `eperm-trans` transport2 EPerm ++-assoc ++-assoc (eperm-++ eperm-++-comm $ eperm-sym pairs-flip) `eperm-trans` eperm-++-right pairs-flip
\func pairs_++-left {A B C : \Type} {f : A -> B -> C} {l1 l2 : Array A} {l : Array B} : pairs f (l1 ++ l2) l = pairs f l1 l ++ pairs f l2 l \elim l1
| nil => idp
| a :: l1 => pmap (_ ++) pairs_++-left *> inv ++-assoc
\func pairs_++-right {A B C : \Type} {f : A -> B -> C} {l : Array A} {l1 l2 : Array B} : EPerm (pairs f l (l1 ++ l2)) (pairs f l l1 ++ pairs f l l2)
=> pairs-flip `eperm-trans` rewrite pairs_++-left (eperm-++ pairs-flip pairs-flip)
\func pairs_map-left {A : \Type} {f : A -> A -> A} (p : \Pi {x y z : A} -> f (f x y) z = f x (f y z)) {a : A} {l l' : Array A}
: pairs f (map (f a) l) l' = map (f a) (pairs f l l') \elim l
| nil => idp
| a' :: l => pmap2 (++) (arrayExt {_} {_} {map (f (f a a')) l'} \lam j => p) (pairs_map-left p) *> inv (map_++ (f a))
\func pairs_map {A B : \Type} {f1 : A -> A -> A} {f2 : B -> B -> B} (g : A -> B)
(p : \Pi {a b : A} -> g (f1 a b) = f2 (g a) (g b)) {l l' : Array A}
: map g (pairs f1 l l') = pairs f2 (map g l) (map g l') \elim l
| nil => idp
| a :: l => map_++ g *> pmap2 (++) (exts \lam j => p) (pairs_map _ p)
\func pairs-assoc {A : \Type} {f : A -> A -> A} (p : \Pi {x y z : A} -> f (f x y) z = f x (f y z)) {l1 l2 l3 : Array A}
: pairs f (pairs f l1 l2) l3 = pairs f l1 (pairs f l2 l3) \elim l1
| nil => idp
| a :: l => pairs_++-left *> pmap2 (++) (pairs_map-left p) (pairs-assoc p)
\func pairs-index {A B C : \Type} {f : A -> B -> C} {l : Array A} {l' : Array B} (i : Fin l.len) (j : Fin l'.len)
: \Sigma (k : Fin (DArray.len {pairs f l l'})) (pairs f l l' k = f (l i) (l' j)) \elim l, i
| a :: l, 0 => (++.index-left {_} {map (f a) l'} j, ++.++_index-left {_} {map (f a) l'} j)
| a :: l, suc i => \have t => pairs-index i j
\in (++.index-right t.1, ++.++_index-right *> t.2)
\lemma pairs-index-inj {A B C : \Type} {f : A -> B -> C} {l : Array A} {l' : Array B} {i i' : Fin l.len} {j j' : Fin l'.len} (p : (pairs-index {_} {_} {_} {f} i j).1 = (pairs-index i' j').1) : \Sigma (i = i') (j = j') \elim l, i, i'
| a :: l, 0, 0 => (idp, nat_fin_= $ inv (later ++.index-left-nat) *> p *> ++.index-left-nat)
| a :: l, 0, suc i' => absurd (++.index-left/=right p)
| a :: l, suc i, 0 => absurd $ ++.index-left/=right (inv p)
| a :: l, suc i, suc i' => \let t => pairs-index-inj (++.index-right-inj p)
\in (pmap fsuc t.1, t.2)
\lemma pairs-index-surj {A B C : \Type} {f : A -> B -> C} {l : Array A} {l' : Array B}
: IsSurj (\lam (s : \Sigma (Fin l.len) (Fin l'.len)) => (pairs-index {_} {_} {_} {f} s.1 s.2).1) \elim l
| nil => \case __
| a :: l => \lam k => \case ++.split-index k \with {
| inl r => inP ((0, r.1), inv r.2)
| inr r => TruncP.map (pairs-index-surj r.1) \lam t => ((suc t.1.1, t.1.2), pmap ++.index-right t.2 *> inv r.2)
}
\lemma pairs-index-equiv {A B C : \Type} {f : A -> B -> C} {l : Array A} {l' : Array B}
: Equiv {\Sigma (Fin l.len) (Fin l'.len)} {Fin (DArray.len {pairs f l l'})} (\lam s => (pairs-index s.1 s.2).1)
=> Equiv.fromInjSurj _ (\lam p => ext (pairs-index-inj p)) pairs-index-surj
\lemma pairs-distr {A : \Type} {R : Semiring} {f : A -> A -> A} (g : A -> R) {l l' : Array A} (p : \Pi {a b : A} -> g (f a b) = g a R.* g b)
: R.BigSum (map g (pairs f l l')) = R.BigSum (map g l) R.* R.BigSum (map g l')
=> inv R.FinSum=BigSum *> R.FinSum_Equiv pairs.pairs-index-equiv *> pmap R.FinSum (ext \lam s => rewrite (pairs.pairs-index s.1 s.2).2 p) *> inv R.FinSum-distr *> pmap2 (R.*) R.FinSum=BigSum R.FinSum=BigSum
}