-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module TensorFlow.Minimize
    ( Minimizer
    , minimizeWith
    , gradientDescent
    , OneOfAdamDataTypes
    , AdamConfig(..)
    , adam
    , adam'
    ) where

import Data.Complex (Complex)
import Data.Int (Int8,Int16,Int32,Int64)
import Data.Word (Word8,Word16,Word32,Word64)
import Control.Monad (zipWithM)
import Data.Default (Default(..))
import Data.List (zipWith4)
import Data.Maybe (fromMaybe)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (assign, initializedVariable)
import qualified TensorFlow.Variable as TF

-- | Functions that minimize a loss w.r.t. a set of 'TF.Variable's.
--
-- Generally only performs one step of an iterative algorithm.
--
-- 'Minimizer's are defined as a function of the gradients instead of
-- the loss so that users can apply transformations to the gradients.
--
-- The first argument is the list of variables being optimized and the
-- second is the gradient of the loss w.r.t. each of those variables
-- (in the same order); the result is a control node that performs one
-- update step when run.
type Minimizer a =
    forall m. TF.MonadBuild m =>
    [TF.Variable a] -> [TF.Tensor TF.Value a] -> m TF.ControlNode

-- | Convenience wrapper around 'TF.gradients' and a 'Minimizer'.
--
-- Differentiates the loss w.r.t. each parameter and feeds the resulting
-- gradients straight into the given 'Minimizer'.
minimizeWith :: (TF.MonadBuild m, TF.GradientCompatible a)
             => Minimizer a
             -> TF.Tensor v a    -- ^ Loss.
             -> [TF.Variable a]  -- ^ Parameters of the loss function.
             -> m TF.ControlNode
minimizeWith minimizer loss params =
    TF.gradients loss params >>= minimizer params

-- | Perform one step of the gradient descent algorithm.
--
-- Each variable is updated in place via @assignAdd@ with
-- @-learningRate * gradient@.
gradientDescent :: TF.GradientCompatible a
                => a  -- ^ Learning rate.
                -> Minimizer a
gradientDescent learningRate params grads =
    TF.withNameScope "gradientDescent" $ do
        -- One in-place update per (parameter, gradient) pair.
        let applyGrad param grad =
                TF.assignAdd param (TF.scalar (-learningRate) `TF.mul` grad)
        -- Group all the per-variable updates into a single control node.
        TF.group =<< zipWithM applyGrad params grads

-- | Hyperparameters for one step of the adam algorithm (see 'adam'').
data AdamConfig t = AdamConfig
    { adamLearningRate :: t  -- ^ Step size.
    , adamBeta1        :: t  -- ^ Decay rate for the 1st moment estimate.
    , adamBeta2        :: t  -- ^ Decay rate for the 2nd moment estimate.
    , adamEpsilon      :: t  -- ^ Small constant for numerical stability.
    }

-- | Constraint synonym listing the tensor element types accepted by the
-- adam implementation ('adam' / 'adam'').
type OneOfAdamDataTypes t =
        TF.OneOf '[ Complex Double, Complex Float
                  , Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8
                  , Double, Float] t

instance Fractional t => Default (AdamConfig t) where
  -- Recommended defaults from the adam paper:
  -- learning rate 0.001, beta1 0.9, beta2 0.999, epsilon 1e-8.
  def = AdamConfig 0.001 0.9 0.999 1e-8

-- | Perform one step of the adam algorithm with default hyperparameters.
--
-- See https://arxiv.org/abs/1412.6980.
--
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t
adam = adam' def

-- | Perform one step of the adam algorithm with explicit hyperparameters.
--
-- NOTE: Currently requires all 'TF.Variable's to have an
-- 'TF.initializedValue', since the moment accumulators are created with
-- the same shape as each parameter's initial value.
adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t
adam' config params grads = TF.withNameScope "adam" $ do
    let lr = TF.scalar (adamLearningRate config)
        beta1 = TF.scalar (adamBeta1 config)
        beta2 = TF.scalar (adamBeta2 config)
        epsilon = TF.scalar (adamEpsilon config)
    -- Create adam state variables: first (ms) and second (vs) moment
    -- accumulators, one per parameter, zero-initialized with the shape
    -- of the parameter's initial value.
    let errorMsg = "TensorFlow.Minimize.adam requires an initial value for all variables"
        initVal = fromMaybe (error errorMsg) . TF.initializedValue
    ms <- mapM (TF.initializedVariable . TF.zerosLike . initVal) params
    vs <- mapM (TF.initializedVariable . TF.zerosLike . initVal) params
    -- Running powers of beta1/beta2 used for bias correction.
    beta1Power <- TF.initializedVariable beta1
    beta2Power <- TF.initializedVariable beta2
    -- Perform adam update.
    let applyGrad param m v =
            TF.resourceApplyAdam param m v
                                 (TF.readValue beta1Power)
                                 (TF.readValue beta2Power)
                                 lr beta1 beta2 epsilon
    updateVars <- sequence $ zipWith4 applyGrad params ms vs grads
    -- Update beta power variables only after the adam update has run.
    let updateBeta betaPower beta =
            TF.withControlDependencies updateVars
                (TF.assign betaPower (TF.readValue betaPower `TF.mul` beta))
    updateBeta1 <- updateBeta beta1Power beta1
    updateBeta2 <- updateBeta beta2Power beta2
    TF.group (updateBeta1:updateBeta2:updateVars)