{-
  Mirror of https://github.com/tensorflow/haskell.git
  (synced 2024-11-05 18:49:41 +01:00; stars: 1, watchers: 0, forks: 0)

  File: tensorflow-haskell/tensorflow-ops/tests/MatrixTest.hs
  (37 lines, 1.3 KiB, Haskell)

  Earliest blame timestamp: 2017-04-28 02:05:34 +02:00
-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}
import Control.Monad.IO.Class (liftIO)
import Control.Monad (replicateM_)
-- 2017-04-28 02:05:34 +02:00
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
-- 2017-06-21 05:50:46 +02:00
import qualified TensorFlow.GenOps.Core as TF (square)
import qualified TensorFlow.Minimize as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable)
import qualified TensorFlow.Variable as TF
-- 2017-04-28 02:05:34 +02:00
import Test.Framework (defaultMain, Test)
-- 2017-04-28 02:05:34 +02:00
import Test.Framework.Providers.HUnit (testCase)
import TensorFlow.Test (assertAllClose)
-- | Builds a 'Float' tensor of the requested shape whose entries come from
-- 'TF.truncatedNormal'. The shape's dimensions are passed to that op as an
-- integer vector.
randomParam :: TF.Shape -> TF.Session (TF.Tensor TF.Value Float)
randomParam requested =
    case requested of
        TF.Shape dims -> TF.truncatedNormal (TF.vector dims)
-- | Learns a factorisation @left `matMul` right@ of a 2x2 all-ones matrix.
-- Here @left@ is a 2x1 variable and @right@ is a 1x2 variable, both
-- initialised by 'randomParam'.
--
-- Training runs 1000 plain gradient-descent steps (learning rate 0.01) on the
-- mean of the squared residual. Afterwards the test asserts that the learned
-- outer product, flattened row-major, is close to the target.
fitMatrix :: Test
fitMatrix = testCase "fitMatrix" $ TF.runSession $ do
    left  <- TF.initializedVariable =<< randomParam [2, 1]
    right <- TF.initializedVariable =<< randomParam [1, 2]
    let target   = TF.constant [2, 2] targetValues
        approx   = TF.readValue left `TF.matMul` TF.readValue right
        residual = target `TF.sub` approx
        loss     = TF.reduceMean (TF.square residual)
    step <- TF.minimizeWith (TF.gradientDescent 0.01) loss [left, right]
    replicateM_ 1000 (TF.run step)
    (leftVals, rightVals) <- TF.run (TF.readValue left, TF.readValue right)
    -- Outer product of the learned factors in row-major order:
    -- [l0*r0, l0*r1, l1*r0, l1*r1]. This should match the all-ones target.
    let outer = V.concatMap (\a -> V.map (a *) rightVals) leftVals
    liftIO $ assertAllClose (V.fromList targetValues) outer
  where
    -- Row-major contents of the 2x2 target matrix.
    targetValues = [1, 1, 1, 1] :: [Float]
-- | Runs this module's tests through test-framework's 'defaultMain'.
main :: IO ()
main = defaultMain allTests
  where
    allTests = [fitMatrix]