mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Support more data types for Adam.
This commit is contained in:
parent
43eebd22ad
commit
0c95543385
|
@ -12,6 +12,8 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE ConstraintKinds #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
|
@ -22,11 +24,15 @@ module TensorFlow.Minimize
|
||||||
( Minimizer
|
( Minimizer
|
||||||
, minimizeWith
|
, minimizeWith
|
||||||
, gradientDescent
|
, gradientDescent
|
||||||
|
, OneOfAdamDataTypes
|
||||||
, AdamConfig(..)
|
, AdamConfig(..)
|
||||||
, adam
|
, adam
|
||||||
, adam'
|
, adam'
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.Complex
|
||||||
|
import Data.Int
|
||||||
|
import Data.Word
|
||||||
import Control.Monad (zipWithM)
|
import Control.Monad (zipWithM)
|
||||||
import Data.Default (Default(..))
|
import Data.Default (Default(..))
|
||||||
import Data.List (zipWith4)
|
import Data.List (zipWith4)
|
||||||
|
@ -67,14 +73,19 @@ gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $
|
||||||
|
|
||||||
-- TODO: Support more than Float in adam.
|
-- TODO: Support more than Float in adam.
|
||||||
|
|
||||||
data AdamConfig = AdamConfig
|
data AdamConfig t = AdamConfig
|
||||||
{ adamLearningRate :: Float
|
{ adamLearningRate :: t
|
||||||
, adamBeta1 :: Float
|
, adamBeta1 :: t
|
||||||
, adamBeta2 :: Float
|
, adamBeta2 :: t
|
||||||
, adamEpsilon :: Float
|
, adamEpsilon :: t
|
||||||
}
|
}
|
||||||
|
|
||||||
instance Default AdamConfig where
|
type OneOfAdamDataTypes t =
|
||||||
|
TF.OneOf '[ Complex Double, Complex Float
|
||||||
|
, Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8
|
||||||
|
, Double, Float] t
|
||||||
|
|
||||||
|
instance Fractional t => Default (AdamConfig t) where
|
||||||
-- Recommended defaults from the adam paper.
|
-- Recommended defaults from the adam paper.
|
||||||
def = AdamConfig 0.001 0.9 0.999 1e-8
|
def = AdamConfig 0.001 0.9 0.999 1e-8
|
||||||
|
|
||||||
|
@ -83,10 +94,10 @@ instance Default AdamConfig where
|
||||||
-- See https://arxiv.org/abs/1412.6980.
|
-- See https://arxiv.org/abs/1412.6980.
|
||||||
--
|
--
|
||||||
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
|
-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
|
||||||
adam :: Minimizer Float
|
adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t
|
||||||
adam = adam' def
|
adam = adam' def
|
||||||
|
|
||||||
adam' :: AdamConfig -> Minimizer Float
|
adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t
|
||||||
adam' config params grads = TF.withNameScope "adam" $ do
|
adam' config params grads = TF.withNameScope "adam" $ do
|
||||||
let lr = TF.scalar (adamLearningRate config)
|
let lr = TF.scalar (adamLearningRate config)
|
||||||
beta1 = TF.scalar (adamBeta1 config)
|
beta1 = TF.scalar (adamBeta1 config)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user