Mirror of https://github.com/tensorflow/haskell.git (synced 2025-02-17 05:25:05 +01:00)
Add Minimize module with gradient descent and adam implementations (#125)
commit 0603a6987b (parent a86d424cac)
8 changed files with 203 additions and 55 deletions
README.md | 22

@@ -20,14 +20,15 @@ Toy example of a linear regression model
 ([full code](tensorflow-ops/tests/RegressionTest.hs)):
 
 ```haskell
-import Control.Monad (replicateM, replicateM_, zipWithM)
+import Control.Monad (replicateM, replicateM_)
 import System.Random (randomIO)
 import Test.HUnit (assertBool)
 
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.GenOps.Core as TF
-import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Minimize as TF
+import qualified TensorFlow.Ops as TF hiding (initializedVariable)
+import qualified TensorFlow.Variable as TF
 
 main :: IO ()
 main = do
@@ -48,23 +49,14 @@ fit xData yData = TF.runSession $ do
     w <- TF.initializedVariable 0
     b <- TF.initializedVariable 0
     -- Define the loss function.
-    let yHat = (x `TF.mul` w) `TF.add` b
+    let yHat = (x `TF.mul` TF.readValue w) `TF.add` TF.readValue b
         loss = TF.square (yHat `TF.sub` y)
     -- Optimize with gradient descent.
-    trainStep <- gradientDescent 0.001 loss [w, b]
+    trainStep <- TF.minimizeWith (TF.gradientDescent 0.001) loss [w, b]
     replicateM_ 1000 (TF.run trainStep)
     -- Return the learned parameters.
-    (TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
+    (TF.Scalar w', TF.Scalar b') <- TF.run (TF.readValue w, TF.readValue b)
     return (w', b')
-
-gradientDescent :: Float
-                -> TF.Tensor TF.Build Float
-                -> [TF.Tensor TF.Ref Float]
-                -> TF.Session TF.ControlNode
-gradientDescent alpha loss params = do
-    let applyGrad param grad =
-            TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
-    TF.group =<< zipWithM applyGrad params =<< TF.gradients loss params
 ```
 
 # Installation Instructions

@@ -15,7 +15,7 @@
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE OverloadedLists #-}
 
-import Control.Monad (zipWithM, when, forM_)
+import Control.Monad (forM_, when)
 import Control.Monad.IO.Class (liftIO)
 import Data.Int (Int32, Int64)
 import Data.List (genericLength)
@@ -23,9 +23,9 @@ import qualified Data.Text.IO as T
 import qualified Data.Vector as V
 
 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF hiding (initializedVariable, zeroInitializedVariable)
 import qualified TensorFlow.Variable as TF
+import qualified TensorFlow.Minimize as TF
 
 import TensorFlow.Examples.MNIST.InputData
 import TensorFlow.Examples.MNIST.Parse
@@ -87,11 +87,7 @@ createModel = do
         loss =
             reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
         params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
-    grads <- TF.gradients loss params
-
-    let lr = TF.scalar 0.00001
-        applyGrad param grad = TF.assignAdd param (negate $ lr `TF.mul` grad)
-    trainStep <- TF.group =<< zipWithM applyGrad params grads
+    trainStep <- TF.minimizeWith TF.adam loss params
 
     let correctPredictions = TF.equal predict labels
     errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
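
The hunk above replaces a hand-rolled update rule with a fixed 0.00001 learning rate by the new `TF.adam`, which uses the paper defaults. If a non-default rate is still wanted, a sketch like the one below should work with the new module; `slowAdam` and the 1e-5 value are illustrative, and the snippet assumes the `data-default` package is available for `def`.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Default (def)
import qualified TensorFlow.Minimize as TF

-- Adam with a smaller, illustrative learning rate; the other fields keep
-- the defaults from the 'Default' instance of 'AdamConfig'.
slowAdam :: TF.Minimizer Float
slowAdam = TF.adam' def { TF.adamLearningRate = 1e-5 }

-- In createModel this would replace the default:
--     trainStep <- TF.minimizeWith slowAdam loss params
```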

@@ -22,7 +22,8 @@
 {-# LANGUAGE ViewPatterns #-}
 
 module TensorFlow.Gradient
-    ( gradients
+    ( GradientCompatible
+    , gradients
     ) where
 
 import Control.Monad (forM, zipWithM)
tensorflow-ops/src/TensorFlow/Minimize.hs | 115 (new file)

@@ -0,0 +1,115 @@
+-- Copyright 2016 TensorFlow authors.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+
+module TensorFlow.Minimize
+    ( Minimizer
+    , minimizeWith
+    , gradientDescent
+    , AdamConfig(..)
+    , adam
+    , adam'
+    ) where
+
+import Control.Monad (zipWithM)
+import Data.Default (Default(..))
+import Data.List (zipWith4)
+import Data.Maybe (fromMaybe)
+
+import qualified TensorFlow.Core as TF
+import qualified TensorFlow.Gradient as TF
+import qualified TensorFlow.Ops as TF hiding (assign, initializedVariable)
+import qualified TensorFlow.Variable as TF
+
+-- | Functions that minimize a loss w.r.t. a set of 'TF.Variable's.
+--
+-- Generally only performs one step of an iterative algorithm.
+--
+-- 'Minimizer's are defined as a function of the gradients instead of
+-- the loss so that users can apply transformations to the gradients.
+type Minimizer a =
+    forall m. TF.MonadBuild m =>
+    [TF.Variable a] -> [TF.Tensor TF.Value a] -> m TF.ControlNode
+
+-- | Convenience wrapper around 'TF.gradients' and a 'Minimizer'.
+minimizeWith :: (TF.MonadBuild m, TF.GradientCompatible a)
+             => Minimizer a
+             -> TF.Tensor v a    -- ^ Loss.
+             -> [TF.Variable a]  -- ^ Parameters of the loss function.
+             -> m TF.ControlNode
+minimizeWith minimizer loss params =
+    TF.gradients loss params >>= minimizer params
+
+-- | Perform one step of the gradient descent algorithm.
+gradientDescent :: TF.GradientCompatible a
+                => a  -- ^ Learning rate.
+                -> Minimizer a
+gradientDescent learningRate params grads = TF.withNameScope "gradientDescent" $ do
+    let applyGrad param grad =
+            TF.assignAdd param (TF.scalar (-learningRate) `TF.mul` grad)
+    TF.group =<< zipWithM applyGrad params grads
+
+-- TODO: Support more than Float in adam.
+
+data AdamConfig = AdamConfig
+    { adamLearningRate :: Float
+    , adamBeta1 :: Float
+    , adamBeta2 :: Float
+    , adamEpsilon :: Float
+    }
+
+instance Default AdamConfig where
+  -- Recommended defaults from the adam paper.
+  def = AdamConfig 0.001 0.9 0.999 1e-8
+
+-- | Perform one step of the adam algorithm.
+--
+-- See https://arxiv.org/abs/1412.6980.
+--
+-- NOTE: Currently requires all 'TF.Variable's to have an 'TF.initializedValue'.
+adam :: Minimizer Float
+adam = adam' def
+
+adam' :: AdamConfig -> Minimizer Float
+adam' config params grads = TF.withNameScope "adam" $ do
+    let lr = TF.scalar (adamLearningRate config)
+        beta1 = TF.scalar (adamBeta1 config)
+        beta2 = TF.scalar (adamBeta2 config)
+        epsilon = TF.scalar (adamEpsilon config)
+    -- Create adam state variables.
+    let errorMsg = "TensorFlow.Minimize.adam requires an initial value for all variables"
+        initVal = fromMaybe (error errorMsg) . TF.initializedValue
+    ms <- mapM (TF.initializedVariable . TF.zerosLike . initVal) params
+    vs <- mapM (TF.initializedVariable . TF.zerosLike . initVal) params
+    beta1Power <- TF.initializedVariable beta1
+    beta2Power <- TF.initializedVariable beta2
+    -- Perform adam update.
+    let applyGrad param m v =
+            TF.resourceApplyAdam param m v
+                                 (TF.readValue beta1Power)
+                                 (TF.readValue beta2Power)
+                                 lr beta1 beta2 epsilon
+    updateVars <- sequence $ zipWith4 applyGrad params ms vs grads
+    -- Update beta variables after adam update.
+    let updateBeta betaPower beta =
+            TF.withControlDependencies updateVars
+                (TF.assign betaPower (TF.readValue betaPower `TF.mul` beta))
+    updateBeta1 <- updateBeta beta1Power beta1
+    updateBeta2 <- updateBeta beta2Power beta2
+    TF.group (updateBeta1:updateBeta2:updateVars)
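
The doc comment on 'Minimizer' notes that minimizers receive the gradients rather than the loss, so callers can transform the gradients first. A minimal sketch of that idea, not part of this commit: `scaledGradientDescent` is a hypothetical name and the scale factor is arbitrary; it rescales each gradient and then delegates to the new `gradientDescent`.

```haskell
{-# LANGUAGE RankNTypes #-}

import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF
import qualified TensorFlow.Minimize as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable)

-- Scale every gradient by a constant factor, then take an ordinary
-- gradient-descent step on the scaled gradients.
scaledGradientDescent :: Float  -- ^ Gradient scale factor.
                      -> Float  -- ^ Learning rate.
                      -> TF.Minimizer Float
scaledGradientDescent scale lr params grads = do
    scaled <- mapM (TF.render . (TF.scalar scale `TF.mul`)) grads
    TF.gradientDescent lr params scaled
```

This would be used like any other minimizer, e.g. `TF.minimizeWith (scaledGradientDescent 0.5 0.001) loss [w, b]`.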

@@ -6,6 +6,8 @@
 -- TODO: given that distinction, figure out a good story around
 -- gradients and save/restore. Then, merge this module into
 -- TensorFlow.Ops.
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RecursiveDo #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE OverloadedStrings #-}
@@ -23,8 +25,13 @@ module TensorFlow.Variable
     , assign'
     , assignAdd
     , assignAdd'
+    , resourceApplyAdam
+    , resourceApplyAdam'
     ) where
 
+import qualified Data.Complex
+import qualified Data.Int
+import qualified Data.Word
 import Data.Text.Encoding (encodeUtf8)
 import Lens.Family2 ((.~), (&))
 import TensorFlow.Core
@@ -133,3 +140,55 @@ assignAdd = assignAdd' id
 assignAdd' :: (MonadBuild m, TensorType a)
            => OpParams -> Variable a -> Tensor v a -> m ControlNode
 assignAdd' params (Variable h _) v = CoreOps.assignAddVariableOp' params h v
+
+-- | Update '*var' according to the Adam algorithm.
+--
+-- lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+-- m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
+-- v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
+-- variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
+resourceApplyAdam ::
+    (MonadBuild m,
+     OneOf '[(Data.Complex.Complex Double),
+             (Data.Complex.Complex Float),
+             Data.Int.Int16,
+             Data.Int.Int32,
+             Data.Int.Int64, Data.Int.Int8,
+             Data.Word.Word16,
+             Data.Word.Word8, Double,
+             Float] t)
+    => Variable t -- ^ __var__: Should be from a Variable().
+    -> Variable t -- ^ __m__: Should be from a Variable().
+    -> Variable t -- ^ __v__: Should be from a Variable().
+    -> Tensor v1 t -- ^ __beta1_power__: Must be a scalar.
+    -> Tensor v2 t -- ^ __beta2_power__: Must be a scalar.
+    -> Tensor v3 t -- ^ __lr__: Scaling factor. Must be a scalar.
+    -> Tensor v4 t -- ^ __beta1__: Momentum factor. Must be a scalar.
+    -> Tensor v5 t -- ^ __beta2__: Momentum factor. Must be a scalar.
+    -> Tensor v6 t -- ^ __epsilon__: Ridge term. Must be a scalar.
+    -> Tensor v7 t -- ^ __grad__: The gradient.
+    -> m (ControlNode)
+resourceApplyAdam = resourceApplyAdam' id
+
+resourceApplyAdam' ::
+    (MonadBuild m,
+     OneOf '[(Data.Complex.Complex Double),
+             (Data.Complex.Complex Float),
+             Data.Int.Int16, Data.Int.Int32,
+             Data.Int.Int64, Data.Int.Int8,
+             Data.Word.Word16, Data.Word.Word8, Double,
+             Float] t)
+    => OpParams
+    -> Variable t -- ^ __var__: Should be from a Variable().
+    -> Variable t -- ^ __m__: Should be from a Variable().
+    -> Variable t -- ^ __v__: Should be from a Variable().
+    -> Tensor v1 t -- ^ __beta1_power__: Must be a scalar.
+    -> Tensor v2 t -- ^ __beta2_power__: Must be a scalar.
+    -> Tensor v3 t -- ^ __lr__: Scaling factor. Must be a scalar.
+    -> Tensor v4 t -- ^ __beta1__: Momentum factor. Must be a scalar.
+    -> Tensor v5 t -- ^ __beta2__: Momentum factor. Must be a scalar.
+    -> Tensor v6 t -- ^ __epsilon__: Ridge term. Must be a scalar.
+    -> Tensor v7 t -- ^ __grad__: The gradient.
+    -> m (ControlNode)
+resourceApplyAdam' params (Variable var _) (Variable m _) (Variable v _) =
+    CoreOps.resourceApplyAdam' params var m v
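
The update rules quoted in the resourceApplyAdam haddock can be written out on plain Floats for intuition. The `adamStep` helper below is purely illustrative and not part of the bindings; the real op performs the same update on tensors, in place, inside the TensorFlow runtime.

```haskell
-- One scalar Adam step, following the formulas in the haddock above:
-- bias-corrected step size, then first/second moment updates.
adamStep :: Float  -- ^ learning_rate
         -> Float  -- ^ beta1
         -> Float  -- ^ beta2
         -> Float  -- ^ epsilon
         -> Int    -- ^ t, the step number (starting at 1)
         -> Float  -- ^ g_t, the gradient
         -> (Float, Float, Float)  -- ^ (variable, m_{t-1}, v_{t-1})
         -> (Float, Float, Float)  -- ^ (variable', m_t, v_t)
adamStep learningRate beta1 beta2 epsilon t g (var, m, v) =
    (var - lrT * m' / (sqrt v' + epsilon), m', v')
  where
    lrT = learningRate * sqrt (1 - beta2 ^ t) / (1 - beta1 ^ t)
    m'  = beta1 * m + (1 - beta1) * g
    v'  = beta2 * v + (1 - beta2) * g * g
```

This also makes explicit why adam' in the new Minimize module keeps running beta1Power and beta2Power variables: they cache beta1^t and beta2^t between steps for the bias correction.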

@@ -17,6 +17,7 @@ library
   exposed-modules: TensorFlow.Gradient
                  , TensorFlow.Ops
                  , TensorFlow.EmbeddingOps
+                 , TensorFlow.Minimize
                  , TensorFlow.NN
                  , TensorFlow.Queue
                  , TensorFlow.Variable

@@ -2,13 +2,14 @@
 {-# LANGUAGE OverloadedLists #-}
 
 import Control.Monad.IO.Class (liftIO)
-import Control.Monad (replicateM_, zipWithM)
+import Control.Monad (replicateM_)
 
-import qualified TensorFlow.GenOps.Core as TF (square, rank)
-import qualified TensorFlow.Core as TF
-import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
 import qualified Data.Vector as V
+import qualified TensorFlow.Core as TF
+import qualified TensorFlow.GenOps.Core as TF (square, rank)
+import qualified TensorFlow.Minimize as TF
+import qualified TensorFlow.Ops as TF hiding (initializedVariable)
+import qualified TensorFlow.Variable as TF
 
 import Test.Framework (defaultMain, Test)
 import Test.Framework.Providers.HUnit (testCase)
@@ -26,22 +27,13 @@ fitMatrix = testCase "fitMatrix" $ TF.runSession $ do
     v <- TF.initializedVariable =<< randomParam [1, 2]
     let ones = [1, 1, 1, 1] :: [Float]
         matx = TF.constant [2, 2] ones
-        diff = matx `TF.sub` (u `TF.matMul` v)
+        diff = matx `TF.sub` (TF.readValue u `TF.matMul` TF.readValue v)
         loss = reduceMean $ TF.square diff
-    trainStep <- gradientDescent 0.01 loss [u, v]
+    trainStep <- TF.minimizeWith (TF.gradientDescent 0.01) loss [u, v]
     replicateM_ 1000 (TF.run trainStep)
-    (u',v') <- TF.run (u, v)
+    (u',v') <- TF.run (TF.readValue u, TF.readValue v)
     -- ones = u * v
     liftIO $ assertAllClose (V.fromList ones) ((*) <$> u' <*> v')
 
-gradientDescent :: Float
-                -> TF.Tensor TF.Build Float
-                -> [TF.Tensor TF.Ref Float]
-                -> TF.Session TF.ControlNode
-gradientDescent alpha loss params = do
-    let applyGrad param grad =
-            TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
-    TF.group =<< zipWithM applyGrad params =<< TF.gradients loss params
-
 main :: IO ()
 main = defaultMain [ fitMatrix ]

@@ -1,13 +1,14 @@
 -- | Simple linear regression example for the README.
 
-import Control.Monad (replicateM, replicateM_, zipWithM)
+import Control.Monad (replicateM, replicateM_)
 import System.Random (randomIO)
 import Test.HUnit (assertBool)
 
 import qualified TensorFlow.Core as TF
 import qualified TensorFlow.GenOps.Core as TF
-import qualified TensorFlow.Gradient as TF
-import qualified TensorFlow.Ops as TF
+import qualified TensorFlow.Minimize as TF
+import qualified TensorFlow.Ops as TF hiding (initializedVariable)
+import qualified TensorFlow.Variable as TF
 
 main :: IO ()
 main = do
@@ -28,20 +29,11 @@ fit xData yData = TF.runSession $ do
     w <- TF.initializedVariable 0
     b <- TF.initializedVariable 0
     -- Define the loss function.
-    let yHat = (x `TF.mul` w) `TF.add` b
+    let yHat = (x `TF.mul` TF.readValue w) `TF.add` TF.readValue b
         loss = TF.square (yHat `TF.sub` y)
     -- Optimize with gradient descent.
-    trainStep <- gradientDescent 0.001 loss [w, b]
+    trainStep <- TF.minimizeWith (TF.gradientDescent 0.001) loss [w, b]
     replicateM_ 1000 (TF.run trainStep)
     -- Return the learned parameters.
-    (TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
+    (TF.Scalar w', TF.Scalar b') <- TF.run (TF.readValue w, TF.readValue b)
     return (w', b')
-
-gradientDescent :: Float
-                -> TF.Tensor TF.Build Float
-                -> [TF.Tensor TF.Ref Float]
-                -> TF.Session TF.ControlNode
-gradientDescent alpha loss params = do
-    let applyGrad param grad =
-            TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
-    TF.group =<< zipWithM applyGrad params =<< TF.gradients loss params