diff --git a/tensorflow-ops/src/TensorFlow/Minimize.hs b/tensorflow-ops/src/TensorFlow/Minimize.hs index 0062d54..f068be1 100644 --- a/tensorflow-ops/src/TensorFlow/Minimize.hs +++ b/tensorflow-ops/src/TensorFlow/Minimize.hs @@ -12,6 +12,8 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} @@ -22,11 +24,15 @@ module TensorFlow.Minimize ( Minimizer , minimizeWith , gradientDescent + , OneOfAdamDataTypes , AdamConfig(..) , adam , adam' ) where +import Data.Complex +import Data.Int +import Data.Word import Control.Monad (zipWithM) import Data.Default (Default(..)) import Data.List (zipWith4) @@ -67,14 +73,19 @@ gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $ -- TODO: Support more than Float in adam. -data AdamConfig = AdamConfig - { adamLearningRate :: Float - , adamBeta1 :: Float - , adamBeta2 :: Float - , adamEpsilon :: Float +data AdamConfig t = AdamConfig + { adamLearningRate :: t + , adamBeta1 :: t + , adamBeta2 :: t + , adamEpsilon :: t } -instance Default AdamConfig where +type OneOfAdamDataTypes t = + TF.OneOf '[ Complex Double, Complex Float + , Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8 + , Double, Float] t + +instance Fractional t => Default (AdamConfig t) where -- Recommended defaults from the adam paper. def = AdamConfig 0.001 0.9 0.999 1e-8 @@ -83,10 +94,10 @@ instance Default AdamConfig where -- See https://arxiv.org/abs/1412.6980. -- -- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'. -adam :: Minimizer Float +adam :: (OneOfAdamDataTypes t, Fractional t) => Minimizer t adam = adam' def -adam' :: AdamConfig -> Minimizer Float +adam' :: OneOfAdamDataTypes t => AdamConfig t -> Minimizer t adam' config params grads = TF.withNameScope "adam" $ do let lr = TF.scalar (adamLearningRate config) beta1 = TF.scalar (adamBeta1 config)