Fix TensorData encode and decode for Bool (#49)

This commit is contained in:
fkm3 2016-12-12 19:40:32 -08:00 committed by GitHub
parent cc08520dc7
commit 91f508eb5c
3 changed files with 18 additions and 6 deletions

View File

@ -197,6 +197,7 @@ Test-Suite TypesTest
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-proto
, transformers

View File

@ -37,6 +37,7 @@ import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
@ -52,17 +53,22 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
TF.runSession $ do
let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
boolData = V.fromList [True, True, False, True, False, False]
f <- TF.build $ TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3]
b <- TF.build $ TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
, TF.feed s (TF.encodeTensorData [2,3] stringData)
, TF.feed b (TF.encodeTensorData [2,3] boolData)
]
-- It is an error to fetch a tensor that is being fed, so the tensors
-- are passed through identity.
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s)
-- Do something idempotent to the tensors to verify that tensorflow can
-- handle the encoding. Originally this used `TF.identity`, but that
-- wasn't enough to catch a bug in the encoding of Bool.
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
liftIO $ do
floatData @=? f'
stringData @=? s'
boolData @=? b'
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)

View File

@ -241,13 +241,18 @@ instance TensorType ByteString where
-- Convert to Vector Word8.
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
instance TensorType Bool where
tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
decodeTensorData =
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
encodeTensorData (Shape xs) =
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
where
fromBool x = if x then 1 else 0 :: Word8
instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64