Support more data types for Adam.

jcmartin 2020-11-05 10:07:03 +00:00
parent 43eebd22ad
commit 0c95543385
1 changed file with 19 additions and 8 deletions


@@ -12,6 +12,8 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
@@ -22,11 +24,15 @@ module TensorFlow.Minimize
    ( Minimizer
    , minimizeWith
    , gradientDescent
    , OneOfAdamDataTypes
    , AdamConfig(..)
    , adam
    , adam'
    ) where
import Data.Complex
import Data.Int
import Data.Word
import Control.Monad (zipWithM)
import Data.Default (Default(..))
import Data.List (zipWith4)
@@ -67,14 +73,19 @@ gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $
-- TODO: Support more than Float in adam.
data AdamConfig = AdamConfig
    { adamLearningRate :: Float
    , adamBeta1 :: Float
    , adamBeta2 :: Float
    , adamEpsilon :: Float
data AdamConfig t = AdamConfig
    { adamLearningRate :: t
    , adamBeta1 :: t
    , adamBeta2 :: t
    , adamEpsilon :: t
    }
instance Default AdamConfig where
type OneOfAdamDataTypes t =
    TF.OneOf '[ Complex Double, Complex Float
              , Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8
              , Double, Float] t
instance Fractional t => Default (AdamConfig t) where
    -- Recommended defaults from the adam paper.
    def = AdamConfig 0.001 0.9 0.999 1e-8
@@ -83,10 +94,10 @@ instance Default AdamConfig where
-- See https://arxiv.org/abs/1412.6980.
--
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
adam :: Minimizer Float
adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t
adam = adam' def
adam' :: AdamConfig -> Minimizer Float
adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t
adam' config params grads = TF.withNameScope "adam" $ do
    let lr = TF.scalar (adamLearningRate config)
        beta1 = TF.scalar (adamBeta1 config)
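
A usage sketch (not part of this commit), written in the style of this package's linear-regression example: with this change, adam' accepts an AdamConfig parameterized over the element type, so a config can be built from the new Fractional-based Default instance. The fit wrapper, the data, and the learning rate below are made up for illustration, and the element type is kept at Float.

import Control.Monad (replicateM_)
import Data.Default (def)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.Minimize as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable)
import qualified TensorFlow.Variable as TF

-- Fit y = w*x + b to the given points with the generalized Adam optimizer.
fit :: [Float] -> [Float] -> IO (Float, Float)
fit xData yData = TF.runSession $ do
    -- Constants for the inputs and targets.
    let x = TF.vector xData
        y = TF.vector yData
    -- Scalar variables for slope and intercept; both carry an
    -- initializedValue, as adam currently requires.
    w <- TF.initializedVariable 0
    b <- TF.initializedVariable 0
    -- Squared-error loss.
    let yHat = (x `TF.mul` TF.readValue w) `TF.add` TF.readValue b
        err = yHat `TF.sub` y
        loss = err `TF.mul` err
    -- AdamConfig is now parameterized over the element type; 'def' comes
    -- from the new Fractional-constrained Default instance.
    let cfg = def { TF.adamLearningRate = 0.01 } :: TF.AdamConfig Float
    trainStep <- TF.minimizeWith (TF.adam' cfg) loss [w, b]
    replicateM_ 1000 (TF.run_ trainStep)
    (TF.Scalar w', TF.Scalar b') <- TF.run (TF.readValue w, TF.readValue b)
    return (w', b')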