Added einsum and test.
This commit is contained in:
parent
9cc48c3f4b
commit
b30d9a52c1
|
@ -89,6 +89,8 @@ module TensorFlow.Ops
|
|||
, CoreOps.identity'
|
||||
, CoreOps.matMul
|
||||
, CoreOps.matMul'
|
||||
, einsum
|
||||
, einsum'
|
||||
, matTranspose
|
||||
, matTranspose'
|
||||
, CoreOps.mean
|
||||
|
@ -202,6 +204,13 @@ instance ( TensorType a
|
|||
signum = CoreOps.sign
|
||||
negate = CoreOps.neg
|
||||
|
||||
-- | Einstein summation
|
||||
einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t
|
||||
einsum = einsum' id
|
||||
|
||||
einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t
|
||||
einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation))
|
||||
|
||||
matTranspose :: TensorType a => Tensor e a -> Tensor Build a
|
||||
matTranspose = matTranspose' id
|
||||
|
||||
|
|
|
@ -104,6 +104,20 @@ testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
|
|||
f1 <- TF.run w
|
||||
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
|
||||
|
||||
-- | Test einstein summation
|
||||
testEinsum :: Test
|
||||
testEinsum = testCase "testEinsum" $ TF.runSession $ do
|
||||
-- Matrix multiply
|
||||
let matA = TF.constant (TF.Shape [3,3]) [1..9 :: Float]
|
||||
let matB = TF.constant (TF.Shape [3,1]) [1..3 :: Float]
|
||||
matMulOut <- TF.run $ TF.matMul matA matB
|
||||
einsumOut <- TF.run $ TF.einsum "ij,jk->ik" [matA,matB]
|
||||
liftIO $ (matMulOut :: V.Vector Float) @=? einsumOut
|
||||
-- Hadamard multiply
|
||||
hadMulOut <- TF.run $ TF.mul matA matA
|
||||
einsumHad <- TF.run $ TF.einsum "ij,ij->ij" [matA,matA]
|
||||
liftIO $ (hadMulOut :: V.Vector Float) @=? einsumHad
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain
|
||||
[ testSaveRestore
|
||||
|
@ -112,4 +126,5 @@ main = defaultMain
|
|||
, testPlaceholderCse
|
||||
, testScalarFeedCse
|
||||
, testRereadRef
|
||||
, testEinsum
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue