diff --git a/tensorflow-ops/src/TensorFlow/Ops.hs b/tensorflow-ops/src/TensorFlow/Ops.hs index 825bc64..e123a41 100644 --- a/tensorflow-ops/src/TensorFlow/Ops.hs +++ b/tensorflow-ops/src/TensorFlow/Ops.hs @@ -203,9 +203,15 @@ matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32]) placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a) placeholder = placeholder' id -placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a) +placeholder' :: forall m a . (MonadBuild m, TensorType a) + => OpParams -> Shape -> m (Tensor Value a) placeholder' params pShape - = render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape)) + -- Note: we don't use CoreOps.placeholder' since that op isn't stateful, + -- and thus would be CSE'd. + = build $ buildOp $ opDef "Placeholder" + & opAttr "dtype" .~ tensorType (undefined :: a) + & opAttr "shape" .~ pShape + & params -- | Creates a variable initialized to the given value. -- Initialization happens next time session runs. diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index 53b14cb..26adc39 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -12,6 +12,7 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} module Main where @@ -68,9 +69,20 @@ testSaveRestore = testCase "testSaveRestore" $ TF.run v liftIO $ TF.Scalar 134 @=? result +-- | Test that 'placeholder' is not CSE'd. +testPlaceholderCse :: Test +testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do + p1 <- TF.placeholder [] + p2 <- TF.placeholder [] + let enc :: Float -> TF.TensorData Float + enc n = TF.encodeTensorData [] (V.fromList [n]) + result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2 + liftIO $ result @=? TF.Scalar 5 + main :: IO () main = googleTest [ testSaveRestore , testSize , testReducedShape + , testPlaceholderCse ]