mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
Cleanup
This commit is contained in:
parent
5e6237370b
commit
e351dfaacd
|
@ -30,18 +30,16 @@ module TensorFlow.Minimize
|
|||
import Control.Monad (zipWithM)
|
||||
import Data.Default (Default (..))
|
||||
import Data.List (zipWith4)
|
||||
import Data.Maybe (fromMaybe)
|
||||
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Ops as TF (scalar, mul, zerosLike, shape)
|
||||
import qualified TensorFlow.Ops as TF (scalar, mul, zerosLike)
|
||||
import qualified TensorFlow.Variable as TF
|
||||
|
||||
import qualified TensorFlow.Tensor as TF (Rendered, ToTensor)
|
||||
|
||||
import qualified TensorFlow.GenOps.Core as TFO (applyAdam, assignAdd, assign)
|
||||
import qualified TensorFlow.Ops as TFO (assign, initializedVariable,
|
||||
scalar, zeroInitializedVariable)
|
||||
import qualified TensorFlow.Ops as TFO (initializedVariable, zeroInitializedVariable)
|
||||
|
||||
-- | Functions that minimize a loss w.r.t. a set of 'TF.Variable's or 'TF.Tensor TF.Ref's.
|
||||
--
|
||||
|
@ -71,7 +69,7 @@ minimizeWith ::
|
|||
-> TF.Tensor v a -- ^ Loss.
|
||||
-> [t a] -- ^ Parameters of the loss function.
|
||||
-> m TF.ControlNode
|
||||
minimizeWith minimizer loss params = TF.gradients loss params >>= minimize minimizer params
|
||||
minimizeWith m loss params = TF.gradients loss params >>= minimize m params
|
||||
|
||||
-- | Perform one step of the gradient descent algorithm for TF.Variable.
|
||||
gradientDescent ::
|
||||
|
@ -111,15 +109,13 @@ adam = adam' def
|
|||
|
||||
adam' :: AdamConfig -> Minimizer TF.Variable Float TF.Build
|
||||
adam' config =
|
||||
let errorMsg = "TensorFlow.Minimize.adam requires an initial value for all variables"
|
||||
initVal = fromMaybe (error errorMsg) . TF.initializedValue
|
||||
in adam''
|
||||
config
|
||||
(mapM (TF.initializedVariable . TF.zerosLike . TF.readValue))
|
||||
TF.initializedVariable
|
||||
TF.resourceApplyAdam
|
||||
TF.readValue
|
||||
TF.assign
|
||||
adam''
|
||||
config
|
||||
(mapM (TF.initializedVariable . TF.zerosLike . TF.readValue))
|
||||
TF.initializedVariable
|
||||
TF.resourceApplyAdam
|
||||
TF.readValue
|
||||
TF.assign
|
||||
|
||||
adamRef :: [TF.Shape] -> Minimizer (TF.Tensor TF.Ref) Float TF.Build
|
||||
adamRef = adamRef' def
|
||||
|
|
Loading…
Reference in New Issue
Block a user