1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-27 03:59:46 +01:00
tensorflow-haskell/tensorflow-ops/tests/MatrixTest.hs
2017-06-20 20:50:46 -07:00

36 lines
1.3 KiB
Haskell

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}
import Control.Monad.IO.Class (liftIO)
import Control.Monad (replicateM_)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (square)
import qualified TensorFlow.Minimize as TF
import qualified TensorFlow.Ops as TF hiding (initializedVariable)
import qualified TensorFlow.Variable as TF
import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase)
import TensorFlow.Test (assertAllClose)
-- | Build a 'Float' tensor of the given shape by passing the shape's
-- dimensions, packed into a rank-1 tensor, to 'TF.truncatedNormal'.
randomParam :: TF.Shape -> TF.Session (TF.Tensor TF.Value Float)
randomParam (TF.Shape dims) = TF.truncatedNormal dimsTensor
  where
    -- The sampler takes the requested shape as a tensor, not a 'TF.Shape'.
    dimsTensor = TF.vector dims
-- | Fit a @2x1@ variable and a @1x2@ variable so that their matrix product
-- approximates a @2x2@ matrix of ones, then check the learned factors.
--
-- The loss is the mean of the squared element-wise difference between the
-- target and the product; it is minimized with plain gradient descent.
fitMatrix :: Test
fitMatrix = testCase "fitMatrix" . TF.runSession $ do
    colVar <- randomParam [2, 1] >>= TF.initializedVariable
    rowVar <- randomParam [1, 2] >>= TF.initializedVariable
    let target = [1, 1, 1, 1] :: [Float]
        approx = TF.readValue colVar `TF.matMul` TF.readValue rowVar
        residual = TF.constant [2, 2] target `TF.sub` approx
        loss = TF.reduceMean (TF.square residual)
    trainStep <-
        TF.minimizeWith (TF.gradientDescent learningRate) loss [colVar, rowVar]
    replicateM_ trainingSteps (TF.run trainStep)
    (colVals, rowVals) <- TF.run (TF.readValue colVar, TF.readValue rowVar)
    -- Expect target = colVar * rowVar. The Applicative product of the two
    -- fetched vectors enumerates every pairwise product, first operand
    -- outermost, which lines up with the flat layout of 'target'.
    liftIO $
        assertAllClose (V.fromList target) ((*) <$> colVals <*> rowVals)
  where
    -- Step size handed to the gradient-descent minimizer.
    learningRate = 0.01
    -- Number of times the training step is executed.
    trainingSteps = 1000
-- | Entry point of the test executable: runs every test case in this module
-- through test-framework's 'defaultMain'.
main :: IO ()
main = defaultMain tests
  where
    tests :: [Test]
    tests = [fitMatrix]