1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00
This commit is contained in:
Rik van der Kleij 2019-02-04 17:41:27 +01:00
parent 5e6237370b
commit e351dfaacd

View File

@ -30,18 +30,16 @@ module TensorFlow.Minimize
import Control.Monad (zipWithM)
import Data.Default (Default (..))
import Data.List (zipWith4)
import Data.Maybe (fromMaybe)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF (scalar, mul, zerosLike, shape)
import qualified TensorFlow.Ops as TF (scalar, mul, zerosLike)
import qualified TensorFlow.Variable as TF
import qualified TensorFlow.Tensor as TF (Rendered, ToTensor)
import qualified TensorFlow.GenOps.Core as TFO (applyAdam, assignAdd, assign)
import qualified TensorFlow.Ops as TFO (assign, initializedVariable,
scalar, zeroInitializedVariable)
import qualified TensorFlow.Ops as TFO (initializedVariable, zeroInitializedVariable)
-- | Functions that minimize a loss w.r.t. a set of 'TF.Variable's or 'TF.Tensor TF.Ref's.
--
@ -71,7 +69,7 @@ minimizeWith ::
-> TF.Tensor v a -- ^ Loss.
-> [t a] -- ^ Parameters of the loss function.
-> m TF.ControlNode
minimizeWith minimizer loss params = TF.gradients loss params >>= minimize minimizer params
minimizeWith m loss params = TF.gradients loss params >>= minimize m params
-- | Perform one step of the gradient descent algorithm for TF.Variable.
gradientDescent ::
@ -111,15 +109,13 @@ adam = adam' def
adam' :: AdamConfig -> Minimizer TF.Variable Float TF.Build
adam' config =
let errorMsg = "TensorFlow.Minimize.adam requires an initial value for all variables"
initVal = fromMaybe (error errorMsg) . TF.initializedValue
in adam''
config
(mapM (TF.initializedVariable . TF.zerosLike . TF.readValue))
TF.initializedVariable
TF.resourceApplyAdam
TF.readValue
TF.assign
adam''
config
(mapM (TF.initializedVariable . TF.zerosLike . TF.readValue))
TF.initializedVariable
TF.resourceApplyAdam
TF.readValue
TF.assign
adamRef :: [TF.Shape] -> Minimizer (TF.Tensor TF.Ref) Float TF.Build
adamRef = adamRef' def