diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal index 80d5a03..94c7478 100644 --- a/tensorflow-ops/tensorflow-ops.cabal +++ b/tensorflow-ops/tensorflow-ops.cabal @@ -197,6 +197,7 @@ Test-Suite TypesTest , lens-family , google-shim , tensorflow + , tensorflow-core-ops , tensorflow-ops , tensorflow-proto , transformers diff --git a/tensorflow-ops/tests/TypesTest.hs b/tensorflow-ops/tests/TypesTest.hs index 3364a1a..b0c8579 100644 --- a/tensorflow-ops/tests/TypesTest.hs +++ b/tensorflow-ops/tests/TypesTest.hs @@ -37,6 +37,7 @@ import qualified Data.ByteString.Char8 as B8 import qualified Data.Vector as V import qualified TensorFlow.ControlFlow as TF +import qualified TensorFlow.GenOps.Core as TF (select) import qualified TensorFlow.Ops as TF import qualified TensorFlow.Session as TF import qualified TensorFlow.Tensor as TF @@ -52,17 +53,22 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $ TF.runSession $ do let floatData = V.fromList [1..6 :: Float] stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]] + boolData = V.fromList [True, True, False, True, False, False] f <- TF.build $ TF.placeholder [2,3] s <- TF.build $ TF.placeholder [2,3] + b <- TF.build $ TF.placeholder [2,3] let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) , TF.feed s (TF.encodeTensorData [2,3] stringData) + , TF.feed b (TF.encodeTensorData [2,3] boolData) ] - -- It is an error to fetch a tensor that is being fed, so the tensors - -- are passed through identity. - (f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s) + -- Do something idempotent to the tensors to verify that tensorflow can + -- handle the encoding. Originally this used `TF.identity`, but that + -- wasn't enough to catch a bug in the encoding of Bool. + (f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b) liftIO $ do floatData @=? f' stringData @=? s' + boolData @=? b' data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a) diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index f3a9b41..497942b 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -241,13 +241,18 @@ instance TensorType ByteString where -- Convert to Vector Word8. byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes - +-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes +-- encoding more expensive. It may make sense to define a custom boolean type. instance TensorType Bool where tensorType _ = DT_BOOL tensorRefType _ = DT_BOOL_REF tensorVal = boolVal - decodeTensorData = simpleDecode - encodeTensorData = simpleEncode + decodeTensorData = + S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData + encodeTensorData (Shape xs) = + TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert + where + fromBool x = if x then 1 else 0 :: Word8 instance TensorType (Complex Float) where tensorType _ = DT_COMPLEX64