{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | Functions for minimizing a loss with respect to a set of variables.
module TensorFlow.Minimize
    ( Minimizer
    , minimizeWith
    , gradientDescent
    , AdamConfig(..)
    , adam
    , adam'
    ) where

import Control.Monad (zipWithM)
import Data.Default (Default(..))
import Data.List (zipWith4)
import Data.Maybe (fromMaybe)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (assign, initializedVariable)
import qualified TensorFlow.Variable as TF

-- | Functions that minimize a loss with respect to a set of 'TF.Variable's.
--
-- Generally only performs one step of an iterative algorithm.
--
-- 'Minimizer's are defined as a function of the gradients instead of
-- the loss so that users can apply transformations to the gradients.
type Minimizer a =
    forall m. TF.MonadBuild m =>
    [TF.Variable a] -> [TF.Tensor TF.Value a] -> m TF.ControlNode
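
-- A minimal sketch (not part of this module's API) of why 'Minimizer' is a
-- function of the gradients rather than the loss: callers can transform the
-- gradients before applying an update.  'halvedGradientDescent' is an
-- illustrative name only.
--
-- > halvedGradientDescent :: Float -> Minimizer Float
-- > halvedGradientDescent lr params grads = do
-- >     grads' <- mapM (TF.render . TF.mul (TF.scalar 0.5)) grads
-- >     gradientDescent lr params grads'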

-- | Convenience wrapper around 'TF.gradients' and a 'Minimizer'.
minimizeWith :: (TF.MonadBuild m, TF.GradientCompatible a)
             => Minimizer a
             -> TF.Tensor v a    -- ^ Loss.
             -> [TF.Variable a]  -- ^ Parameters of the loss function.
             -> m TF.ControlNode
minimizeWith minimizer loss params =
    TF.gradients loss params >>= minimizer params
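
-- A minimal usage sketch, assuming this module's imports plus
-- 'Control.Monad.replicateM_'; 'fitConstant' and 'w' are illustrative names,
-- not part of this module:
--
-- > fitConstant :: IO Float
-- > fitConstant = TF.runSession $ do
-- >     w <- TF.initializedVariable 0
-- >     let diff = TF.readValue w - TF.scalar 3
-- >         loss = diff * diff
-- >     trainStep <- minimizeWith (gradientDescent 0.1) loss [w]
-- >     replicateM_ 100 (TF.run_ trainStep)
-- >     TF.Scalar w' <- TF.run (TF.readValue w)
-- >     return w'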

-- | Perform one step of the gradient descent algorithm.
gradientDescent :: TF.GradientCompatible a
                => a  -- ^ Learning rate.
                -> Minimizer a
gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $ do
    -- param := param + (-learningRate) * grad, expressed via 'TF.assignAdd'.
    let applyGrad param grad =
            TF.assignAdd param (TF.scalar (-learningRate) `TF.mul` grad)
    TF.group =<< zipWithM applyGrad params grads
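
-- Used on its own (outside 'minimizeWith'), 'gradientDescent' simply applies
-- one update step to already-computed gradients.  A sketch, assuming 'params'
-- and 'grads' are in scope inside a 'TF.Session':
--
-- > step <- gradientDescent 0.01 params grads
-- > TF.run_ step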

-- | Hyper-parameters for the Adam optimizer ('adam'').
data AdamConfig = AdamConfig
    { adamLearningRate :: Float
    , adamBeta1        :: Float
    , adamBeta2        :: Float
    , adamEpsilon      :: Float
    }

instance Default AdamConfig where
    -- Recommended defaults from the Adam paper.
    def = AdamConfig 0.001 0.9 0.999 1e-8
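
-- Individual fields can be overridden with a record update, e.g. (a usage
-- sketch):
--
-- > adam' def { adamLearningRate = 1e-4 }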

-- | Perform one step of the Adam algorithm with the default configuration.
--
-- See https://arxiv.org/abs/1412.6980.
--
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
adam :: Minimizer Float
adam = adam' def

-- | Perform one step of the Adam algorithm with the given configuration.
adam' :: AdamConfig -> Minimizer Float
adam' config params grads = TF.withNameScope "adam" $ do
    let lr = TF.scalar (adamLearningRate config)
        beta1 = TF.scalar (adamBeta1 config)
        beta2 = TF.scalar (adamBeta2 config)
        epsilon = TF.scalar (adamEpsilon config)
    -- Create Adam state variables: first- and second-moment accumulators
    -- (one pair per parameter) plus the running powers of beta1 and beta2.
    let errorMsg = "TensorFlow.Minimize.adam requires an initial value for all variables"
        initVal = fromMaybe (error errorMsg) . TF.initializedValue
    ms <- mapM (TF.initializedVariable . TF.zerosLike . initVal) params
    vs <- mapM (TF.initializedVariable . TF.zerosLike . initVal) params
    beta1Power <- TF.initializedVariable beta1
    beta2Power <- TF.initializedVariable beta2
    -- Perform the Adam update on each parameter.
    let applyGrad param m v =
            TF.resourceApplyAdam param m v
                                 (TF.readValue beta1Power)
                                 (TF.readValue beta2Power)
                                 lr beta1 beta2 epsilon
    updateVars <- sequence $ zipWith4 applyGrad params ms vs grads
    -- Update the beta power accumulators after the Adam update.
    let updateBeta betaPower beta =
            TF.withControlDependencies updateVars
                (TF.assign betaPower (TF.readValue betaPower `TF.mul` beta))
    updateBeta1 <- updateBeta beta1Power beta1
    updateBeta2 <- updateBeta beta2Power beta2
    -- Group everything into a single control node.
    TF.group (updateBeta1:updateBeta2:updateVars)
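
-- A minimal usage sketch (illustrative names): because the moment
-- accumulators above are shaped from 'TF.initializedValue', every variable
-- passed to 'adam' or 'adam'' must have been created with
-- 'TF.initializedVariable', e.g.:
--
-- > w <- TF.initializedVariable (TF.zeros (TF.Shape [10]))
-- > -- 'loss' is assumed to be some scalar tensor built from 'TF.readValue w'.
-- > trainStep <- minimizeWith adam loss [w]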