mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
added matrix factorization test (#101)
This commit is contained in:
parent
51c883684b
commit
09c792b84c
2 changed files with 66 additions and 0 deletions
|
@ -45,6 +45,24 @@ Test-Suite RegressionTest
|
||||||
, tensorflow-core-ops
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
|
|
||||||
|
Test-Suite MatrixTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: MatrixTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: base
|
||||||
|
, HUnit
|
||||||
|
, random
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
|
, tensorflow-ops
|
||||||
|
, tensorflow-test
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
Test-Suite BuildTest
|
Test-Suite BuildTest
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
type: exitcode-stdio-1.0
|
type: exitcode-stdio-1.0
|
||||||
|
|
48
tensorflow-ops/tests/MatrixTest.hs
Normal file
48
tensorflow-ops/tests/MatrixTest.hs
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import Control.Monad (replicateM_, zipWithM)
|
||||||
|
|
||||||
|
import qualified TensorFlow.GenOps.Core as TF (square, rank)
|
||||||
|
import qualified TensorFlow.Core as TF
|
||||||
|
import qualified TensorFlow.Gradient as TF
|
||||||
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import Test.Framework (Test)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import TensorFlow.Test (assertAllClose)
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
|
||||||
|
randomParam :: TF.Shape -> TF.Session (TF.Tensor TF.Value Float)
|
||||||
|
randomParam (TF.Shape shape) = TF.truncatedNormal (TF.vector shape)
|
||||||
|
|
||||||
|
reduceMean :: TF.Tensor v Float -> TF.Tensor TF.Build Float
|
||||||
|
reduceMean xs = TF.mean xs (TF.range 0 (TF.rank xs) 1)
|
||||||
|
|
||||||
|
fitMatrix :: Test
|
||||||
|
fitMatrix = testCase "fitMatrix" $ TF.runSession $ do
|
||||||
|
u <- TF.initializedVariable =<< randomParam [2, 1]
|
||||||
|
v <- TF.initializedVariable =<< randomParam [1, 2]
|
||||||
|
let ones = [1, 1, 1, 1] :: [Float]
|
||||||
|
matx = TF.constant [2, 2] ones
|
||||||
|
diff = matx `TF.sub` (u `TF.matMul` v)
|
||||||
|
loss = reduceMean $ TF.square diff
|
||||||
|
trainStep <- gradientDescent 0.01 loss [u, v]
|
||||||
|
replicateM_ 300 (TF.run trainStep)
|
||||||
|
(u',v') <- TF.run (u, v)
|
||||||
|
-- ones = u * v
|
||||||
|
liftIO $ assertAllClose (V.fromList ones) ((*) <$> u' <*> v')
|
||||||
|
|
||||||
|
gradientDescent :: Float
|
||||||
|
-> TF.Tensor TF.Build Float
|
||||||
|
-> [TF.Tensor TF.Ref Float]
|
||||||
|
-> TF.Session TF.ControlNode
|
||||||
|
gradientDescent alpha loss params = do
|
||||||
|
let applyGrad param grad =
|
||||||
|
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||||
|
TF.group =<< zipWithM applyGrad params =<< TF.gradients loss params
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ fitMatrix ]
|
Loading…
Reference in a new issue