From b30d9a52c171a72f8984dc3f952cb39073c78a19 Mon Sep 17 00:00:00 2001 From: jcmartin Date: Fri, 28 Aug 2020 09:18:42 +0000 Subject: [PATCH] Added einsum and test. --- tensorflow-ops/src/TensorFlow/Ops.hs | 9 +++++++++ tensorflow-ops/tests/OpsTest.hs | 15 +++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 9fcad6e..384a977 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -89,6 +89,8 @@ module TensorFlow.Ops , CoreOps.identity' , CoreOps.matMul , CoreOps.matMul' + , einsum + , einsum' , matTranspose , matTranspose' , CoreOps.mean @@ -202,6 +204,13 @@ instance ( TensorType a signum = CoreOps.sign negate = CoreOps.neg +-- | Einstein summation +einsum :: TensorType t => ByteString -> [Tensor v t] -> Tensor Build t +einsum = einsum' id + +einsum' :: TensorType t => OpParams -> ByteString -> [Tensor v t] -> Tensor Build t +einsum' params equation = CoreOps.einsum' (params . (opAttr "equation" .~ equation)) + matTranspose :: TensorType a => Tensor e a -> Tensor Build a matTranspose = matTranspose' id diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index 6c08c1a..b0d602d 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -104,6 +104,20 @@ testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do f1 <- TF.run w liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1) +-- | Test einstein summation +testEinsum :: Test +testEinsum = testCase "testEinsum" $ TF.runSession $ do + -- Matrix multiply + let matA = TF.constant (TF.Shape [3,3]) [1..9 :: Float] + let matB = TF.constant (TF.Shape [3,1]) [1..3 :: Float] + matMulOut <- TF.run $ TF.matMul matA matB + einsumOut <- TF.run $ TF.einsum "ij,jk->ik" [matA,matB] + liftIO $ (matMulOut :: V.Vector Float) @=? einsumOut + -- Hadamard multiply + hadMulOut <- TF.run $ TF.mul matA matA + einsumHad <- TF.run $ TF.einsum "ij,ij->ij" [matA,matA] + liftIO $ (hadMulOut :: V.Vector Float) @=? einsumHad + main :: IO () main = defaultMain [ testSaveRestore @@ -112,4 +126,5 @@ main = defaultMain , testPlaceholderCse , testScalarFeedCse , testRereadRef + , testEinsum ]