{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-- | Minimizers that build TensorFlow graph operations updating a set of
-- variables from their gradients: plain gradient descent ('gradientDescent')
-- and Adam ('adam', 'adam''), plus the 'minimizeWith' convenience wrapper that
-- computes the gradients of a loss first.
module TensorFlow.Minimize
( Minimizer
, minimizeWith
, gradientDescent
, OneOfAdamDataTypes
, AdamConfig(..)
, adam
, adam'
) where
import Data.Complex (Complex)
import Data.Int (Int8,Int16,Int32,Int64)
import Data.Word (Word8,Word16,Word32,Word64)
import Control.Monad (zipWithM)
import Data.Default (Default(..))
import Data.List (zipWith4)
import Data.Maybe (fromMaybe)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (assign, initializedVariable)
import qualified TensorFlow.Variable as TF
-- | Functions that minimize a loss with respect to a set of 'TF.Variable's.
--
-- A 'Minimizer' takes the variables and the gradients of the loss with
-- respect to those variables, and builds the graph operations that update
-- the variables. It returns a single 'TF.ControlNode' that triggers them.
-- The gradients are expected to line up positionally with the variables.
--
-- Minimizers take gradients rather than the loss itself, so callers can
-- transform the gradients before they are applied. 'minimizeWith' covers
-- the common "just minimize this loss" case.
type Minimizer a =
    forall m. TF.MonadBuild m =>
    [TF.Variable a] -> [TF.Tensor TF.Value a] -> m TF.ControlNode
-- | Convenience wrapper around 'TF.gradients' and a 'Minimizer'.
--
-- Computes the gradients of @loss@ with respect to @params@, then passes
-- @params@ and those gradients to @minimizer@. Returns the 'TF.ControlNode'
-- that the minimizer builds.
minimizeWith :: (TF.MonadBuild m, TF.GradientCompatible a)
             => Minimizer a
             -> TF.Tensor v a    -- ^ Loss.
             -> [TF.Variable a]  -- ^ Parameters of the loss function.
             -> m TF.ControlNode
minimizeWith minimizer loss params =
    TF.gradients loss params >>= minimizer params
-- | Perform one step of the gradient descent algorithm.
--
-- For each variable/gradient pair this builds
-- @param += (-learningRate) * grad@ using 'TF.assignAdd'. All of those
-- updates are grouped into a single 'TF.ControlNode'. The ops are created
-- under the @gradientDescent@ name scope.
--
-- Variables and gradients are paired positionally with 'zipWithM'. If the
-- two lists differ in length, the surplus elements are ignored.
gradientDescent :: TF.GradientCompatible a
                => a  -- ^ Learning rate.
                -> Minimizer a
gradientDescent learningRate params grads =
    TF.withNameScope "gradientDescent" $ do
        let applyGrad param grad =
                TF.assignAdd param (TF.scalar (-learningRate) `TF.mul` grad)
        TF.group =<< zipWithM applyGrad params grads
-- | Hyper-parameters for the Adam minimizer ('adam'').
--
-- The 'Default' instance provides the settings suggested in the Adam paper
-- (<https://arxiv.org/abs/1412.6980>).
data AdamConfig t = AdamConfig
    { adamLearningRate :: t  -- ^ Step size (alpha in the Adam paper).
    , adamBeta1        :: t  -- ^ Exponential decay rate of the first-moment estimates.
    , adamBeta2        :: t  -- ^ Exponential decay rate of the second-moment estimates.
    , adamEpsilon      :: t  -- ^ Small constant added for numerical stability.
    }
-- | Constraint listing the element types accepted by 'adam' and 'adam''.
type OneOfAdamDataTypes t =
    TF.OneOf '[ Complex Double, Complex Float
              , Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8
              , Double, Float] t
-- | Default settings suggested in the Adam paper (Kingma & Ba):
-- learning rate 0.001, beta1 0.9, beta2 0.999, epsilon 1e-8.
instance Fractional t => Default (AdamConfig t) where
    def = AdamConfig 0.001 0.9 0.999 1e-8
-- | Perform one step of the Adam algorithm using the 'def' 'AdamConfig'.
--
-- See <https://arxiv.org/abs/1412.6980>.
--
-- NOTE: every variable must have an 'TF.initializedValue'. A variable
-- without one leads to an 'error' call.
adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t
adam = adam' def
-- | Perform one step of the Adam algorithm with an explicit 'AdamConfig'.
--
-- Under the @adam@ name scope this builds:
--
-- * two zero-initialised variables per parameter. These are the @m@ and @v@
--   accumulators passed to 'TF.resourceApplyAdam', shaped like the
--   parameter's initial value.
-- * two scalar variables holding the running powers of beta1 and beta2,
--   initialised to beta1 and beta2.
-- * one 'TF.resourceApplyAdam' update per parameter/gradient pair.
-- * updates that multiply each beta-power variable by its beta. These are
--   ordered after all parameter updates.
--
-- The returned 'TF.ControlNode' groups the beta-power updates and the
-- parameter updates. Parameters and gradients are paired positionally.
--
-- NOTE: every variable must have an 'TF.initializedValue'. A variable
-- without one leads to an 'error' call.
adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t
adam' config params grads = TF.withNameScope "adam" $ do
    let lr      = TF.scalar (adamLearningRate config)
        beta1   = TF.scalar (adamBeta1 config)
        beta2   = TF.scalar (adamBeta2 config)
        epsilon = TF.scalar (adamEpsilon config)
    -- Create the Adam state variables. Each moment accumulator starts at
    -- zero with the same shape as its parameter's initial value.
    let errorMsg = "TensorFlow.Minimize.adam requires an initial value for all variables"
        initVal = fromMaybe (error errorMsg) . TF.initializedValue
        zeroSlot param =
            TF.initializedVariable (TF.zerosLike (initVal param))
    ms <- mapM zeroSlot params
    vs <- mapM zeroSlot params
    beta1Power <- TF.initializedVariable beta1
    beta2Power <- TF.initializedVariable beta2
    -- Apply the Adam update to every parameter.
    let applyGrad param m v =
            TF.resourceApplyAdam param m v
                                 (TF.readValue beta1Power)
                                 (TF.readValue beta2Power)
                                 lr beta1 beta2 epsilon
    updateVars <- sequence $ zipWith4 applyGrad params ms vs grads
    -- Advance the beta powers only after every parameter update has run.
    -- This way the updates above read the powers for the current step.
    let updateBeta betaPower beta =
            TF.withControlDependencies updateVars
                (TF.assign betaPower (TF.readValue betaPower `TF.mul` beta))
    updateBeta1 <- updateBeta beta1Power beta1
    updateBeta2 <- updateBeta beta2Power beta2
    TF.group (updateBeta1 : updateBeta2 : updateVars)