mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 14:59:44 +01:00
2c5c879037
This change adds a class that both `Build` and `Session` are instances of: `class MonadBuild m where build :: Build a -> m a`. All stateful ops (generated and manually written) now have a signature that returns an instance of `MonadBuild` (rather than just `Build`). For example: `assign_ :: (MonadBuild m, TensorType t) => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)`. This lets us remove a bunch of spurious calls to `build` in user code. It also lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run` (or `run =<< foo`, which is sometimes nicer when `foo` is a complicated expression). I went ahead and deleted `buildAnd` altogether since it seems to lead to confusion; in particular, a few tests had `buildAnd run . pure`, which is actually equivalent to just `run`.
47 lines
1.6 KiB
Haskell
47 lines
1.6 KiB
Haskell
-- | Simple linear regression example for the README.
|
|
|
|
import Control.Monad (replicateM, replicateM_, zipWithM)
|
|
import System.Random (randomIO)
|
|
import Test.HUnit (assertBool)
|
|
|
|
import qualified TensorFlow.Core as TF
|
|
import qualified TensorFlow.GenOps.Core as TF
|
|
import qualified TensorFlow.Gradient as TF
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
-- | Fit a line to synthetic data drawn from @y = 3x + 8@ and check that the
-- learned slope and intercept are each within 0.001 of the true values.
main :: IO ()
main = do
    -- Sample 100 random inputs and label each one with the ground-truth line.
    inputs <- replicateM 100 randomIO
    let targets = map (\x -> x*3 + 8) inputs
    -- Train the model, then compare the recovered parameters to the truth.
    (slope, intercept) <- fit inputs targets
    assertBool "w == 3" (abs (3 - slope) < 0.001)
    assertBool "b == 8" (abs (8 - intercept) < 0.001)
|
|
|
|
-- | Learn a slope and an intercept for @y ≈ slope * x + intercept@.
--
-- It runs 1000 steps of gradient descent (learning rate 0.001) on the
-- elementwise squared error, starting from zero. It returns
-- @(slope, intercept)@. Assumes both lists have the same length; a mismatch
-- would surface as a TensorFlow shape error (TODO confirm).
fit :: [Float] -> [Float] -> IO (Float, Float)
fit xData yData = TF.runSession $ do
    -- Lift the Haskell lists into constant tensors.
    let xs = TF.vector xData
        ys = TF.vector yData
    -- Both trainable scalars start at zero.
    slope <- TF.initializedVariable 0
    intercept <- TF.initializedVariable 0
    -- Squared error of the prediction (slope * x + intercept) per example.
    let predicted = (xs `TF.mul` slope) `TF.add` intercept
        loss = TF.square (predicted `TF.sub` ys)
    -- Build one combined update op, then execute it a fixed number of times.
    trainStep <- gradientDescent 0.001 loss [slope, intercept]
    replicateM_ 1000 (TF.run trainStep)
    -- Fetch the trained values back into Haskell.
    (TF.Scalar w, TF.Scalar b) <- TF.run (slope, intercept)
    pure (w, b)
|
|
|
|
-- | Build a single control node that applies one plain gradient-descent
-- step, @param := param - alpha * d(loss)/d(param)@, to every parameter.
gradientDescent :: Float                      -- ^ learning rate
                -> TF.Tensor TF.Value Float   -- ^ loss to minimise
                -> [TF.Tensor TF.Ref Float]   -- ^ trainable parameters
                -> TF.Session TF.ControlNode
gradientDescent alpha loss params = do
    grads <- TF.gradients loss params
    updates <- zipWithM step params grads
    -- Group the per-parameter assignments so one 'TF.run' performs them all.
    TF.group updates
  where
    -- Overwrite the parameter with its value moved against the gradient.
    step param grad =
        TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|