Support more data types for Adam.
This commit is contained in:
parent
43eebd22ad
commit
0c95543385
|
@ -12,6 +12,8 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
@ -22,11 +24,15 @@ module TensorFlow.Minimize
|
|||
( Minimizer
|
||||
, minimizeWith
|
||||
, gradientDescent
|
||||
, OneOfAdamDataTypes
|
||||
, AdamConfig(..)
|
||||
, adam
|
||||
, adam'
|
||||
) where
|
||||
|
||||
import Data.Complex
|
||||
import Data.Int
|
||||
import Data.Word
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Default (Default(..))
|
||||
import Data.List (zipWith4)
|
||||
|
@ -67,14 +73,19 @@ gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $
|
|||
|
||||
-- TODO: Support more than Float in adam.
|
||||
|
||||
data AdamConfig = AdamConfig
|
||||
{ adamLearningRate :: Float
|
||||
, adamBeta1 :: Float
|
||||
, adamBeta2 :: Float
|
||||
, adamEpsilon :: Float
|
||||
data AdamConfig t = AdamConfig
|
||||
{ adamLearningRate :: t
|
||||
, adamBeta1 :: t
|
||||
, adamBeta2 :: t
|
||||
, adamEpsilon :: t
|
||||
}
|
||||
|
||||
instance Default AdamConfig where
|
||||
type OneOfAdamDataTypes t =
|
||||
TF.OneOf '[ Complex Double, Complex Float
|
||||
, Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8
|
||||
, Double, Float] t
|
||||
|
||||
instance Fractional t => Default (AdamConfig t) where
|
||||
-- Recommended defaults from the adam paper.
|
||||
def = AdamConfig 0.001 0.9 0.999 1e-8
|
||||
|
||||
|
@ -83,10 +94,10 @@ instance Default AdamConfig where
|
|||
-- See https://arxiv.org/abs/1412.6980.
|
||||
--
|
||||
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
|
||||
adam :: Minimizer Float
|
||||
adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t
|
||||
adam = adam' def
|
||||
|
||||
adam' :: AdamConfig -> Minimizer Float
|
||||
adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t
|
||||
adam' config params grads = TF.withNameScope "adam" $ do
|
||||
let lr = TF.scalar (adamLearningRate config)
|
||||
beta1 = TF.scalar (adamBeta1 config)
|
||||
|
|
Loading…
Reference in New Issue