mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
parent
c99a23b6a7
commit
fdbfd050f8
2 changed files with 20 additions and 2 deletions
|
@ -203,9 +203,15 @@ matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
|||
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||
placeholder = placeholder' id
|
||||
|
||||
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a)
|
||||
placeholder' :: forall m a . (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Shape -> m (Tensor Value a)
|
||||
placeholder' params pShape
|
||||
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape))
|
||||
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
|
||||
-- and thus would be CSE'd.
|
||||
= build $ buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ pShape
|
||||
& params
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Main where
|
||||
|
@ -68,9 +69,20 @@ testSaveRestore = testCase "testSaveRestore" $
|
|||
TF.run v
|
||||
liftIO $ TF.Scalar 134 @=? result
|
||||
|
||||
-- | Test that 'placeholder' is not CSE'd.
|
||||
testPlaceholderCse :: Test
|
||||
testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
||||
p1 <- TF.placeholder []
|
||||
p2 <- TF.placeholder []
|
||||
let enc :: Float -> TF.TensorData Float
|
||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testSaveRestore
|
||||
, testSize
|
||||
, testReducedShape
|
||||
, testPlaceholderCse
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue