\import Algebra.Group
\import Algebra.Meta
\import Algebra.Module
\import Algebra.Module.LinearMap
\import Algebra.Monoid
\import Algebra.Monoid.Category
\import Algebra.Ring
\import Function.Meta
\import Logic
\import Meta
\import Paths
\import Paths.Meta
\open DRing
\open isDiff

{- | A commutative ring in which the invertible elements are dense.
 -
 -   Density is expressed through functions: two functions defined on a subset `U` of the ring
 -   that agree at every point of `U` whose underlying element is invertible agree on all of `U`.
 -}
\class DRing \extends CRing
  -- Density axiom: to show `f p = g p` for every `p` in `U`, it suffices to show it
  -- under the extra hypothesis `Monoid.Inv p.1`.
  | inv-dense {B : \Set} (U : E -> \Prop) (f g : \Sigma (z : E) (U z) -> B) : (\Pi (p : \Sigma (z : E) (u : U z)) -> Monoid.Inv p.1 -> f p = g p) -> \Pi (p : \Sigma (z : E) (u : U z)) -> f p = g p
  \where
    -- Scalar cancellation for module-valued functions: if `p.1 *c f p = p.1 *c g p`
    -- for every `p` in `U`, then `f p = g p` at the given `p`.
    -- Where `p.1` is invertible the scalar is removed with `B.cancel`; `inv-dense`
    -- then extends the equality to every point of `U`.
    \lemma cancel-lem {R : DRing} {B : LModule R} (U : R -> \Prop) (f g : \Sigma (z : R) (U z) -> B) (p : \Sigma (z : R) (u : U z)) (s : \Pi (p : \Sigma (z : R) (u : U z)) -> p.1 *c f p = p.1 *c g p) : f p = g p
      => R.inv-dense U f g (\lam p p-inv => B.cancel p-inv (s p)) p

{- | Differentiability of a function `f` defined on the subset `U` of the module `A`.
 -
 -   A witness assigns to every point `x` of `U`, every direction `a`, and every scalar `t.1`
 -   for which `x.1 + t.1 *c a` still lies in `U`, an element `y` of `B` satisfying
 -   `t.1 *c y = f (x.1 + t.1 *c a) - f x`.
 -   No condition excludes the scalar `0`; `deriv` reads off the value chosen there.
 -}
\func isDiff {R : DRing} {A B : LModule R} {U : A -> \Prop} (f : \Sigma (x : A) (U x) -> B) : \Set
  => \Pi (x : \Sigma (a : A) (U a)) (a : A) (t : \Sigma (r : R) (U (x.1 + r *c a))) -> \Sigma (y : B) (t.1 *c y = f (x.1 + t.1 *c a, t.2) - f x)
  \where {
    -- Any two witnesses of `isDiff f` are equal, so `isDiff f` is a proposition.
    -- Both chosen elements satisfy the same defining equation, hence their multiples
    -- by `t.1` coincide, and `cancel-lem` removes the scalar.
    \use \level levelProp (d d' : isDiff f) : d = d'
      => ext \lam x a t => ext $ cancel-lem _ _ _ t \lam t => (d x a t).2 *> inv (d' x a t).2

    -- The element chosen by `d` depends only on the underlying point `x.1` and the
    -- underlying scalar `t.1`, not on the accompanying proofs of membership in `U`.
    -- The path is built pointwise; the proof components are filled in with `prop-dpi`.
    \lemma isDiff-eq {f : \Sigma (x : A) (U x) -> B} (d : isDiff f) {x x' : \Sigma (a : A) (U a)} (p : x.1 = x'.1) {a : A} {t : \Sigma (r : R) (U (x.1 + r *c a))} {t' : \Sigma (r : R) (U (x'.1 + r *c a))} (q : t.1 = t'.1) : (d x a t).1 = (d x' a t').1
      => path (\lam i => (d (p i, prop-dpi _ _ _ i) a (q i, prop-dpi _ _ _ i)).1)
  }

-- | Differentiability of a function defined on all of `A`: `isDiff` applied to `f` with the empty `\Sigma` type as the domain condition, which imposes no restriction.
\func isDiffT {R : DRing} {A B : LModule R} (f : A -> B) => isDiff (\lam (p : \Sigma A (\Sigma)) => f p.1)

{- | The derivative of `f` at the point `x`, packaged as a linear map from `A` to `B`.
 -
 -   Its value on a direction `a` is the element chosen by the witness `d` at scalar `0`.
 -   The fields `func-+` and `func-*c` prove that this assignment is additive and homogeneous.
 -}
\func deriv {R : DRing} {A B : LModule R} {U : A -> \Prop} {f : \Sigma (x : A) (U x) -> B} (d : isDiff f) (x : \Sigma (a : A) (U a)) : LinearMap A B \cowith
  -- The scalar is `0`; `sub-lem` supplies the proof that `x.1 + 0 *c a` lies in `U`.
  | func a => (d x a (0, sub-lem x)).1
  -- Additivity. `cancel-lem` is applied over scalars `t` for which the three points
  -- `x.1 + t *c (a + a')`, `x.1 + t *c a` and `x.1 + t *c a + t *c a'` all lie in `U`.
  -- It equates the element chosen for direction `a + a'` at `x` with the sum of the element
  -- chosen for `a'` at the shifted point `x.1 + t *c a` and the element chosen for `a` at `x`.
  -- The result `s` is that equality at `t = 0`; `+-comm` and `isDiff-eq` then reorder the
  -- summands and replace the base point `x.1 + 0 *c a` by `x.1`.
  | func-+ {a} {a'} =>
    \have s => cancel-lem (\lam t => \Sigma (U (x.1 + t *c (a + a'))) (U (x.1 + t *c a)) (U (x.1 + t *c a + t *c a'))) (\lam t => (d x (a + a') (t.1, t.2.1)).1) (\lam t => (d (x.1 + t.1 *c a, t.2.2) a' (t.1, t.2.3)).1 + (d x a (t.1, t.2.2)).1) (0, (sub-lem x, sub-lem x, rewrite (A.*c_zro-left, zro-right) (sub-lem x)))
                \lam t => (d _ _ (_, _)).2 *> pmap (`- _) (pmap f (ext $ rewrite *c-ldistr $ inv +-assoc) *> inv zro-right *> pmap (_ +) (inv negative-left) *> inv +-assoc) *> +-assoc *> inv (pmap2 (+) (d _ _ (_, _)).2 (d _ _ (_, _)).2) *> inv B.*c-ldistr
    \in s *> +-comm *> pmap (_ +) (isDiff-eq d (pmap (x.1 +) A.*c_zro-left *> zro-right) idp)
  -- Homogeneity. `cancel-lem` is applied over scalars `t` for which both
  -- `x.1 + t *c (r *c a)` and `x.1 + t * r *c a` lie in `U`.
  -- It equates the element chosen for direction `r *c a` at scalar `t` with `r *c` the
  -- element chosen for direction `a` at scalar `t * r`.
  -- The result `s` is that equality at `t = 0`; `isDiff-eq` with `R.zro_*-left` then
  -- replaces the scalar `0 * r` by `0`.
  | func-*c {r} {a} =>
    \have s => cancel-lem (\lam t => \Sigma (U (x.1 + t *c (r *c a))) (U (x.1 + t * r *c a))) (\lam t => (d x (r *c a) (t.1, t.2.1)).1) (\lam t => r *c (d x a (t.1 * r, t.2.2)).1) (0, (sub-lem x, rewrite *c-assoc $ sub-lem x))
                \lam t => later (rewrite ((d x (r *c a) (t.1, t.2.1)).2, (d x a (t.1 * r, t.2.2)).2) $ pmap (f __ - f x) $ ext $ pmap (x.1 +) $ inv *c-assoc) *> *c-assoc
    \in s *> pmap (r *c) (isDiff-eq d idp R.zro_*-left)
  \where
    -- The point `x.1 + 0 *c a` lies in `U`: the membership proof `x.2` is transported
    -- backwards along the equality `x.1 + 0 *c a = x.1`.
    \lemma sub-lem {A : LModule} {U : A -> \Prop} {a : A} (x : \Sigma (a : A) (U a)) : U (x.1 + 0 *c a)
      => transportInv U (pmap (x.1 +) A.*c_zro-left *> zro-right) x.2

{- | The derivative of a linear map is the map itself.
 -
 -   The element chosen for direction `a` is `f a`, independently of the point and the scalar;
 -   the required equation is derived from `f.func-+` and `f.func-*c`.
 -}
\func linear_diff {R : DRing} {A B : LModule R} (f : LinearMap A B) : isDiffT f
  => \lam x a t => (f a, rewrite (+-comm,f.func-+) $ inv (simplify f.func-*c))

{- | The derivative of a constant map is zero.
 -
 -   The element chosen is `0` for every point, direction and scalar: both sides of the
 -   defining equation reduce to zero, via `B.*c_zro-right` and `negative-right`.
 -}
\func const_diff {R : DRing} {A B : LModule R} (b : B) : isDiffT (\lam (_ : A) => b)
  => \lam x a t => (0, B.*c_zro-right *> inv negative-right)

{- | The derivative of a sum is the sum of derivatives.
 -
 -   For `f` and `g` on the same domain `U`, the element chosen for `f x + g x` is the sum of
 -   the elements chosen by `df` and `dg` at the same point, direction and scalar.
 -   The proof distributes the scalar over the sum, applies both defining equations, and
 -   rearranges the resulting differences.
 -}
\func +_diff {R : DRing} {A B : LModule R} {U : A -> \Prop} {f g : \Sigma (a : A) (U a) -> B} (df : isDiff f) (dg : isDiff g) : isDiff (\lam x => f x + g x)
  => \lam x a t => ((df x a t).1 + (dg x a t).1, B.*c-ldistr *> pmap2 (+) (df x a t).2 (dg x a t).2 *> equation *> inv (pmap (_ +) AddGroup.negative_+))

{- | Differentiable maps are closed under composition.
 -
 -   Here `f` maps the domain `U` into the domain `V` of `g`, and `df` is a differentiability
 -   witness for the underlying `B`-valued function of `f`.
 -   The element chosen for the composite at point `x`, direction `a` and scalar `t.1` is the
 -   one chosen by `dg` at the point `f x`, in the direction chosen by `df`, at the same scalar.
 -
 -   The chain rule `deriv (o_diff df dg) x a = deriv dg (f x) (deriv df x a)` follows from this by `idp`.
 -}
\func o_diff {R : DRing} {A B C : LModule R} {U : A -> \Prop} {V : B -> \Prop}
             {f : \Sigma (a : A) (U a) -> \Sigma (b : B) (V b)} {g : \Sigma (b : B) (V b) -> C}
             (df : isDiff (f __).1) (dg : isDiff g) : isDiff (\lam x => g (f x))
  -- `s` applies `dg` at `f x`. Its membership proof for `V` is obtained from the proof carried
  -- by `f (x.1 + t.1 *c a, t.2)`, after rewriting with the defining equation of `df` and
  -- transporting along a simplified equality.
  => \lam x a t => \have s => dg (f x) (df x a t).1 (t.1, rewrite (df x a t).2 $ transport V simplify (f (x.1 + t.1 *c a, t.2)).2)
                   -- The argument of `g` in the equation of `s` is identified with `f (x.1 + t.1 *c a, t.2)`.
                   \in (s.1, s.2 *> pmap (g __ - _) (ext $ rewrite (df x a t).2 $ +-comm *> simplify))

{- | Bilinear maps are differentiable.
 -
 -   The map `m` is viewed as a function of the pair `p`.
 -   At the point `x`, in the direction `a`, with scalar `t`, the element chosen is
 -   `m x.1 a.2 + m a.1 x.2 + t *c m a.1 a.2`.
 -   The proof expands `m` on the shifted point using linearity in each argument and then
 -   rearranges the resulting sum.
 -}
\func bilinear_diff {R : DRing} {A B C : LModule R} (m : BilinearMap A B C) : isDiffT {R} (\lam (p : \Sigma A B) => m p.1 p.2)
  => \lam (x,_) a (t,_) => (m x.1 a.2 + m a.1 x.2 + t *c m a.1 a.2, rewrite (m.linear-left.func-+, m.linear-right.func-+, m.linear-left.func-*c, m.linear-right.func-*c) $
      later (rewrite (negative-left,zro-left,*c-ldistr,*c-ldistr) $ +-assoc *> pmap (_ +) (pmap (_ + _ *c __) (inv m.linear-right.func-*c) *> inv *c-ldistr *> pmap (_ *c) (inv m.linear-right.func-+))) *> pmap (`+ _) +-assoc *> +-assoc *> +-comm)