2017-04-28 02:05:34 +02:00
|
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
|
|
{-# LANGUAGE OverloadedLists #-}
|
|
|
|
|
|
|
|
import Control.Monad.IO.Class (liftIO)
|
2017-05-26 04:19:22 +02:00
|
|
|
import Control.Monad (replicateM_)
|
2017-04-28 02:05:34 +02:00
|
|
|
|
|
|
|
import qualified Data.Vector as V
|
2017-05-26 04:19:22 +02:00
|
|
|
import qualified TensorFlow.Core as TF
|
|
|
|
import qualified TensorFlow.GenOps.Core as TF (square, rank)
|
|
|
|
import qualified TensorFlow.Minimize as TF
|
|
|
|
import qualified TensorFlow.Ops as TF hiding (initializedVariable)
|
|
|
|
import qualified TensorFlow.Variable as TF
|
2017-04-28 02:05:34 +02:00
|
|
|
|
2017-05-11 00:26:03 +02:00
|
|
|
import Test.Framework (defaultMain, Test)
|
2017-04-28 02:05:34 +02:00
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
|
|
import TensorFlow.Test (assertAllClose)
|
|
|
|
|
|
|
|
-- | Produce a fresh 'Float' tensor of the requested shape whose contents are
-- drawn by 'TF.truncatedNormal'; used as a random starting value for a
-- trainable variable.
randomParam :: TF.Shape -> TF.Session (TF.Tensor TF.Value Float)
randomParam shape = case shape of
    -- The op takes the shape as a 1-D tensor of dimension sizes.
    TF.Shape dims -> TF.truncatedNormal (TF.vector dims)
|
|
|
|
|
|
|
|
-- | Average of every element of @xs@: a mean taken across all of its axes,
-- yielding a scalar tensor.
reduceMean :: TF.Tensor v Float -> TF.Tensor TF.Build Float
reduceMean xs = TF.mean xs allAxes
  where
    -- Axis indices 0, 1, .., rank-1, built inside the graph so the helper
    -- works for an input of any rank.
    allAxes = TF.range 0 (TF.rank xs) 1
|
|
|
|
|
|
|
|
-- | Learn a rank-1 factorisation of the 2x2 all-ones matrix.
--
-- Two variables, @u@ (shape [2, 1]) and @v@ (shape [1, 2]), start from
-- 'randomParam'. They are trained for 1000 steps of gradient descent
-- (learning rate 0.01) on the mean of the squared residual between the
-- target matrix and @u `matMul` v@. The test then checks that every
-- pairwise product of the learned values is close to 1.
fitMatrix :: Test
fitMatrix = testCase "fitMatrix" $ TF.runSession $ do
    u <- TF.initializedVariable =<< randomParam [2, 1]
    v <- TF.initializedVariable =<< randomParam [1, 2]
    let target = replicate 4 1 :: [Float]
        -- Element-wise error of the current rank-1 product against the target.
        residual = TF.constant [2, 2] target
                     `TF.sub` (TF.readValue u `TF.matMul` TF.readValue v)
        loss = reduceMean (TF.square residual)
    trainStep <- TF.minimizeWith (TF.gradientDescent 0.01) loss [u, v]
    replicateM_ 1000 (TF.run trainStep)
    (learnedU, learnedV) <- TF.run (TF.readValue u, TF.readValue v)
    -- A [2,1] x [1,2] matmul is an outer product. Visiting every (a, b) pair
    -- with a from u and b from v, in row-major order, reproduces the
    -- flattened 2x2 result that is compared with the target.
    let outer = do
          a <- learnedU
          b <- learnedV
          pure (a * b)
    liftIO $ assertAllClose (V.fromList target) outer
|
|
|
|
|
|
|
|
-- | Entry point: hand the test list to test-framework's 'defaultMain'.
main :: IO ()
main = defaultMain tests
  where
    tests = [fitMatrix]
|