mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Fix TensorData encode and decode for Bool (#49)
This commit is contained in:
parent
cc08520dc7
commit
91f508eb5c
3 changed files with 18 additions and 6 deletions
|
@ -197,6 +197,7 @@ Test-Suite TypesTest
|
||||||
, lens-family
|
, lens-family
|
||||||
, google-shim
|
, google-shim
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
, transformers
|
, transformers
|
||||||
|
|
|
@ -37,6 +37,7 @@ import qualified Data.ByteString.Char8 as B8
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
import qualified TensorFlow.ControlFlow as TF
|
import qualified TensorFlow.ControlFlow as TF
|
||||||
|
import qualified TensorFlow.GenOps.Core as TF (select)
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
import qualified TensorFlow.Tensor as TF
|
||||||
|
@ -52,17 +53,22 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
let floatData = V.fromList [1..6 :: Float]
|
let floatData = V.fromList [1..6 :: Float]
|
||||||
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
||||||
|
boolData = V.fromList [True, True, False, True, False, False]
|
||||||
f <- TF.build $ TF.placeholder [2,3]
|
f <- TF.build $ TF.placeholder [2,3]
|
||||||
s <- TF.build $ TF.placeholder [2,3]
|
s <- TF.build $ TF.placeholder [2,3]
|
||||||
|
b <- TF.build $ TF.placeholder [2,3]
|
||||||
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||||
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
||||||
|
, TF.feed b (TF.encodeTensorData [2,3] boolData)
|
||||||
]
|
]
|
||||||
-- It is an error to fetch a tensor that is being fed, so the tensors
|
-- Do something idempotent to the tensors to verify that tensorflow can
|
||||||
-- are passed through identity.
|
-- handle the encoding. Originally this used `TF.identity`, but that
|
||||||
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s)
|
-- wasn't enough to catch a bug in the encoding of Bool.
|
||||||
|
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
|
||||||
liftIO $ do
|
liftIO $ do
|
||||||
floatData @=? f'
|
floatData @=? f'
|
||||||
stringData @=? s'
|
stringData @=? s'
|
||||||
|
boolData @=? b'
|
||||||
|
|
||||||
|
|
||||||
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
|
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
|
||||||
|
|
|
@ -241,13 +241,18 @@ instance TensorType ByteString where
|
||||||
-- Convert to Vector Word8.
|
-- Convert to Vector Word8.
|
||||||
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
||||||
|
|
||||||
|
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
||||||
|
-- encoding more expensive. It may make sense to define a custom boolean type.
|
||||||
instance TensorType Bool where
|
instance TensorType Bool where
|
||||||
tensorType _ = DT_BOOL
|
tensorType _ = DT_BOOL
|
||||||
tensorRefType _ = DT_BOOL_REF
|
tensorRefType _ = DT_BOOL_REF
|
||||||
tensorVal = boolVal
|
tensorVal = boolVal
|
||||||
decodeTensorData = simpleDecode
|
decodeTensorData =
|
||||||
encodeTensorData = simpleEncode
|
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
|
||||||
|
encodeTensorData (Shape xs) =
|
||||||
|
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
|
||||||
|
where
|
||||||
|
fromBool x = if x then 1 else 0 :: Word8
|
||||||
|
|
||||||
instance TensorType (Complex Float) where
|
instance TensorType (Complex Float) where
|
||||||
tensorType _ = DT_COMPLEX64
|
tensorType _ = DT_COMPLEX64
|
||||||
|
|
Loading…
Reference in a new issue