mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-18 17:09:43 +01:00
116 lines
4.3 KiB
Haskell
116 lines
4.3 KiB
Haskell
|
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
|
||
|
|
||
|
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | Loss minimization helpers.
--
-- Exposes the 'Minimizer' abstraction (an update rule expressed over
-- gradients), the 'minimizeWith' wrapper that computes gradients for a loss
-- and feeds them to a 'Minimizer', and two concrete update rules:
-- 'gradientDescent' and adam ('adam' / 'adam'' with 'AdamConfig').
module TensorFlow.Minimize
    ( Minimizer
    , minimizeWith
    , gradientDescent
    , AdamConfig(..)
    , adam
    , adam'
    ) where
|
||
|
|
||
|
import Control.Monad (zipWithM)
|
||
|
import Data.Default (Default(..))
|
||
|
import Data.List (zipWith4)
|
||
|
import Data.Maybe (fromMaybe)
|
||
|
|
||
|
import qualified TensorFlow.Core as TF
|
||
|
import qualified TensorFlow.Gradient as TF
|
||
|
import qualified TensorFlow.Ops as TF hiding (assign, initializedVariable)
|
||
|
import qualified TensorFlow.Variable as TF
|
||
|
|
||
|
-- | An update rule that minimizes a loss w.r.t. a set of 'TF.Variable's.
--
-- A 'Minimizer' is given the variables being optimized and a list of
-- gradient tensors (paired with the variables by position) and builds a
-- 'TF.ControlNode' for the update. Generally this is only one step of an
-- iterative algorithm.
--
-- The input is the gradients rather than the loss itself so that users can
-- apply transformations to the gradients before they reach the 'Minimizer'.
type Minimizer a =
    forall m. TF.MonadBuild m
    => [TF.Variable a]
    -> [TF.Tensor TF.Value a]
    -> m TF.ControlNode
|
||
|
|
||
|
-- | Compute the gradients of a loss and hand them to a 'Minimizer'.
--
-- Convenience wrapper: calls 'TF.gradients' on the loss with respect to the
-- given parameters, then passes those same parameters together with the
-- resulting gradients to the supplied 'Minimizer'.
minimizeWith
    :: (TF.MonadBuild m, TF.GradientCompatible a)
    => Minimizer a
    -- ^ Update rule to apply to the computed gradients.
    -> TF.Tensor v a
    -- ^ Loss.
    -> [TF.Variable a]
    -- ^ Parameters of the loss function.
    -> m TF.ControlNode
    -- ^ The node produced by the 'Minimizer'.
minimizeWith minimizer loss params = do
    grads <- TF.gradients loss params
    minimizer params grads
|
||
|
|
||
|
-- | Perform one step of the gradient descent algorithm.
--
-- Within the @gradientDescent@ name scope, each variable is incremented by
-- its gradient multiplied by the negated learning rate, and all of the
-- resulting assignments are grouped into a single 'TF.ControlNode'.
--
-- Variables and gradients are paired by position; 'zipWithM' stops at the
-- end of the shorter list.
gradientDescent
    :: TF.GradientCompatible a
    => a
    -- ^ Learning rate.
    -> Minimizer a
gradientDescent learningRate params grads =
    TF.withNameScope "gradientDescent"
        (zipWithM descend params grads >>= TF.group)
  where
    -- var += (-learningRate) * grad
    descend var grad =
        TF.assignAdd var (TF.mul (TF.scalar (negate learningRate)) grad)
|
||
|
|
||
|
-- TODO: Support more than Float in adam.
|
||
|
|
||
|
-- | Hyperparameters consumed by 'adam''.
--
-- See https://arxiv.org/abs/1412.6980 for the meaning of each parameter.
data AdamConfig = AdamConfig
    { adamLearningRate :: Float
      -- ^ Learning rate passed to the adam update op.
    , adamBeta1 :: Float
      -- ^ Beta1 passed to the adam update op; also the initial value and
      -- per-step multiplier of the running beta1 power.
    , adamBeta2 :: Float
      -- ^ Beta2 passed to the adam update op; also the initial value and
      -- per-step multiplier of the running beta2 power.
    , adamEpsilon :: Float
      -- ^ Epsilon passed to the adam update op.
    }
|
||
|
|
||
|
instance Default AdamConfig where
    -- Recommended defaults from the adam paper.
    def = AdamConfig
        { adamLearningRate = 0.001
        , adamBeta1 = 0.9
        , adamBeta2 = 0.999
        , adamEpsilon = 1e-8
        }
|
||
|
|
||
|
-- | Perform one step of the adam algorithm using the default 'AdamConfig'
-- (the 'def' value of its 'Default' instance).
--
-- See https://arxiv.org/abs/1412.6980.
--
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
adam :: Minimizer Float
adam params grads = adam' def params grads
|
||
|
|
||
|
-- | Perform one step of the adam algorithm with an explicit 'AdamConfig'.
--
-- Everything is built inside the @adam@ name scope. Each invocation:
--
-- 1. creates two zero-initialized state variables per parameter, shaped like
--    that parameter's 'TF.initializedValue';
-- 2. creates two running "beta power" variables, initialized to beta1 and
--    beta2 respectively;
-- 3. applies 'TF.resourceApplyAdam' to every (parameter, state, state,
--    gradient) tuple;
-- 4. multiplies each beta power by its beta, with control dependencies on
--    the updates from step 3 so this happens after them;
-- 5. groups all of the above updates into one 'TF.ControlNode'.
--
-- Parameters and gradients are paired by position; 'zipWith4' stops at the
-- end of the shortest list.
--
-- A parameter without an 'TF.initializedValue' results in a call to 'error'.
adam' :: AdamConfig -> Minimizer Float
adam' config params grads = TF.withNameScope "adam" $ do
    -- Hyperparameters lifted to scalar tensors.
    let learningRate = TF.scalar (adamLearningRate config)
        beta1 = TF.scalar (adamBeta1 config)
        beta2 = TF.scalar (adamBeta2 config)
        epsilon = TF.scalar (adamEpsilon config)
    -- Optimizer state: one zero-filled slot per parameter, created twice
    -- (the two state arguments of the adam update op), followed by the
    -- running powers of the betas.
    let initialValueOf var =
            fromMaybe
                (error "TensorFlow.Minimize.adam requires an initial value for all variables")
                (TF.initializedValue var)
        newZeroSlot var =
            TF.initializedVariable (TF.zerosLike (initialValueOf var))
    mSlots <- traverse newZeroSlot params
    vSlots <- traverse newZeroSlot params
    beta1Pow <- TF.initializedVariable beta1
    beta2Pow <- TF.initializedVariable beta2
    -- Apply the adam update to every parameter.
    let step var mSlot vSlot grad =
            TF.resourceApplyAdam
                var
                mSlot
                vSlot
                (TF.readValue beta1Pow)
                (TF.readValue beta2Pow)
                learningRate
                beta1
                beta2
                epsilon
                grad
    varUpdates <- sequence (zipWith4 step params mSlots vSlots grads)
    -- Advance the beta powers only after all parameter updates have run.
    let advancePower power beta =
            TF.withControlDependencies varUpdates $
                TF.assign power (TF.mul (TF.readValue power) beta)
    beta1Update <- advancePower beta1Pow beta1
    beta2Update <- advancePower beta2Pow beta2
    TF.group (beta1Update : beta2Update : varUpdates)
|