1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-18 00:49:44 +01:00
tensorflow-haskell/tensorflow-ops/tests/RegressionTest.hs
Judah Jacobson d62c614695 Distinguish between "rendered" and "unrendered" Tensors. (#88)
Distinguish between "rendered" and "unrendered" Tensors.

There are now three types of `Tensor`:

- `Tensor Value a`: rendered value
- `Tensor Ref a`: rendered reference
- `Tensor Build a`: unrendered value

The extra bookkeeping makes it easier to track (and enforce) which tensors are
rendered or not.  (The original message pointed to earlier examples where this
distinction was confusing; those references are missing from this copy.)

With this change, pure ops look similar to before, returning `Tensor Build`
instead of `Tensor Value`.  "Stateful" (monadic) ops are unchanged.  For
example:

    add :: OneOf [..] t => Tensor v'1 t -> Tensor v'2 t -> Tensor Build t
    assign :: (MonadBuild m, TensorType t)
           => Tensor Ref t -> Tensor v'2 t -> m (Tensor Ref t)

The `gradients` function now requires that the variables over which it's
differentiating are pre-rendered:

    gradients :: (..., Rendered v2) => Tensor v1 a -> [Tensor v2 a]
              -> m [Tensor Value a]

(`Rendered v2` means that `v2` is either a `Ref` or a `Value`.)

Additionally, the implementation of `gradients` now takes care to render every
intermediate value when performing the reverse accumulation.  I suspect this
fixes an exponential blowup for complicated expressions.
2017-04-06 15:10:33 -07:00

47 lines
1.6 KiB
Haskell

-- | Simple linear regression example for the README.
import Control.Monad (replicateM, replicateM_, zipWithM)
import System.Random (randomIO)
import Test.HUnit (assertBool)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF
-- | Test entry point: draw 100 random x-values, build exact targets from
-- @y = x*3 + 8@, fit a line with 'fit', and assert that the learned slope
-- and intercept are each strictly within 0.001 of the true values.
main :: IO ()
main = do
    -- Noise-free training data for the line y = x*3 + 8.
    samples <- replicateM 100 randomIO
    let targets = map (\x -> x*3 + 8) samples
    -- Train, then check both learned parameters.
    (slope, intercept) <- fit samples targets
    assertBool "w == 3" (closeTo 3 slope)
    assertBool "b == 8" (closeTo 8 intercept)
  where
    -- True when the estimate lies strictly within 0.001 of the expected value.
    closeTo expected actual = abs (expected - actual) < 0.001
-- | Fit @yHat = w*x + b@ to the given samples inside a fresh TensorFlow
-- session. The loss is the element-wise squared error. The function runs
-- 1000 gradient-descent steps with learning rate 0.001 and returns the
-- final @(w, b)@.
fit :: [Float] -> [Float] -> IO (Float, Float)
fit xData yData = TF.runSession $ do
    -- Trainable scalar slope and intercept, both starting at zero.
    w <- TF.initializedVariable 0
    b <- TF.initializedVariable 0
    -- Pure (unrendered) graph pieces: inputs, prediction, and squared error.
    let xs         = TF.vector xData
        ys         = TF.vector yData
        prediction = (xs `TF.mul` w) `TF.add` b
        residual   = prediction `TF.sub` ys
        loss       = TF.square residual
    -- Build one update step and run it repeatedly.
    trainStep <- gradientDescent 0.001 loss [w, b]
    replicateM_ 1000 (TF.run trainStep)
    -- Fetch the learned values and unwrap the scalars.
    unScalars <$> TF.run (w, b)
  where
    unScalars (TF.Scalar w', TF.Scalar b') = (w', b')
-- | Build one plain gradient-descent step. The step takes the gradient of
-- @loss@ with respect to each parameter and assigns
-- @param - alpha * grad@ back to that parameter. The returned control node
-- groups all of the assignments, so running it once updates every parameter
-- once.
gradientDescent :: Float
                -> TF.Tensor TF.Build Float
                -> [TF.Tensor TF.Ref Float]
                -> TF.Session TF.ControlNode
gradientDescent alpha loss params = do
    grads   <- TF.gradients loss params
    updates <- zipWithM descend params grads
    TF.group updates
  where
    -- Assign the parameter its value minus the scaled gradient.
    descend param grad =
        TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))