Added einsum and test.

This commit is contained in:
jcmartin 2020-08-28 09:18:42 +00:00
parent 9cc48c3f4b
commit b30d9a52c1
2 changed files with 24 additions and 0 deletions

View File

@ -89,6 +89,8 @@ module TensorFlow.Ops
, CoreOps.identity'
, CoreOps.matMul
, CoreOps.matMul'
, einsum
, einsum'
, matTranspose
, matTranspose'
, CoreOps.mean
@ -202,6 +204,13 @@ instance ( TensorType a
signum = CoreOps.sign
negate = CoreOps.neg
-- | Einstein summation
einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t
einsum = einsum' id
einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t
einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation))
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
matTranspose = matTranspose' id

View File

@ -104,6 +104,20 @@ testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
f1 <- TF.run w
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
-- | Test einstein summation
testEinsum :: Test
testEinsum = testCase "testEinsum" $ TF.runSession $ do
-- Matrix multiply
let matA = TF.constant (TF.Shape [3,3]) [1..9 :: Float]
let matB = TF.constant (TF.Shape [3,1]) [1..3 :: Float]
matMulOut <- TF.run $ TF.matMul matA matB
einsumOut <- TF.run $ TF.einsum "ij,jk->ik" [matA,matB]
liftIO $ (matMulOut :: V.Vector Float) @=? einsumOut
-- Hadamard multiply
hadMulOut <- TF.run $ TF.mul matA matA
einsumHad <- TF.run $ TF.einsum "ij,ij->ij" [matA,matA]
liftIO $ (hadMulOut :: V.Vector Float) @=? einsumHad
main :: IO ()
main = defaultMain
[ testSaveRestore
@ -112,4 +126,5 @@ main = defaultMain
, testPlaceholderCse
, testScalarFeedCse
, testRereadRef
, testEinsum
]