mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
Switch to new TF_STRING format for TF 2.10
See https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
This commit is contained in:
parent
c5bfff9b4c
commit
63753cc20d
4 changed files with 76 additions and 34 deletions
|
@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
|
||||||
, setSessionConfig
|
, setSessionConfig
|
||||||
, setSessionTarget
|
, setSessionTarget
|
||||||
, getAllOpList
|
, getAllOpList
|
||||||
|
, unsafeTStringToByteString
|
||||||
-- * Internal helper.
|
-- * Internal helper.
|
||||||
, useProtoAsVoidPtrLen
|
, useProtoAsVoidPtrLen
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
|
|
||||||
|
import Control.Exception (assert)
|
||||||
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
||||||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
|
@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
||||||
|
|
||||||
import qualified TensorFlow.Internal.Raw as Raw
|
import qualified TensorFlow.Internal.Raw as Raw
|
||||||
|
|
||||||
|
-- | Interpret a vector of bytes as a TF_TString struct and copy the string it
|
||||||
|
-- points to into a ByteString.
-- The vector is expected to be exactly 'Raw.sizeOfTString' bytes long; this is
-- only checked via 'assert'. NOTE(review): GHC can compile assertions out
-- (e.g. with -fignore-asserts), so callers should not rely on that check.
|
||||||
|
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
|
||||||
|
unsafeTStringToByteString v =
|
||||||
|
assert (S.length v == Raw.sizeOfTString) $
|
||||||
|
unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
|
||||||
|
let tstring = Raw.TString (castPtr tstringPtr)
|
||||||
|
p <- Raw.stringGetDataPointer tstring
|
||||||
|
n <- Raw.stringGetSize tstring
|
||||||
|
B.packCStringLen (p, fromIntegral n)
|
||||||
|
|
||||||
data TensorFlowException = TensorFlowException Raw.Code T.Text
|
data TensorFlowException = TensorFlowException Raw.Code T.Text
|
||||||
deriving (Show, Eq, Typeable)
|
deriving (Show, Eq, Typeable)
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,23 @@ message :: Status -> IO CString
|
||||||
message = {# call TF_Message as ^ #}
|
message = {# call TF_Message as ^ #}
|
||||||
|
|
||||||
|
|
||||||
|
-- TString.
|
||||||
|
{# pointer *TF_TString as TString newtype #}
|
||||||
|
|
||||||
|
-- Size in bytes of one TF_TString struct. NOTE(review): hard-coded; confirm
-- it matches sizeof(TF_TString) in the TensorFlow C headers being linked.
sizeOfTString :: Int
|
||||||
|
sizeOfTString = 24
|
||||||
|
|
||||||
|
-- TF_TString_Type::TF_TSTR_OFFSET
-- Type tag marking a TF_TString as "offset"-encoded; encoders OR it into the
-- two lowest bits of the struct's size field.
|
||||||
|
tstringOffsetTypeTag :: Word32
|
||||||
|
tstringOffsetTypeTag = 2
|
||||||
|
|
||||||
|
-- Foreign binding to the C function TF_StringGetDataPointer; yields the
-- pointer that callers pair with 'stringGetSize' to read the string's bytes.
stringGetDataPointer :: TString -> IO CString
|
||||||
|
stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #}
|
||||||
|
|
||||||
|
-- Foreign binding to the C function TF_StringGetSize; yields the length that
-- callers pair with 'stringGetDataPointer' to read the string's bytes.
stringGetSize :: TString -> IO CULong
|
||||||
|
stringGetSize = {# call TF_StringGetSize as ^ #}
|
||||||
|
|
||||||
|
|
||||||
-- Buffer.
|
-- Buffer.
|
||||||
data Buffer
|
data Buffer
|
||||||
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
|
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
|
||||||
|
|
|
@ -64,6 +64,7 @@ module TensorFlow.Types
|
||||||
, AllTensorTypes
|
, AllTensorTypes
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.Bits (shiftL, (.|.))
|
||||||
import Data.ProtoLens.Message(defMessage)
|
import Data.ProtoLens.Message(defMessage)
|
||||||
import Data.Functor.Identity (Identity(..))
|
import Data.Functor.Identity (Identity(..))
|
||||||
import Data.Complex (Complex)
|
import Data.Complex (Complex)
|
||||||
|
@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder
|
||||||
import qualified Data.ByteString.Lazy as L
|
import qualified Data.ByteString.Lazy as L
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified Data.Vector.Storable as S
|
import qualified Data.Vector.Storable as S
|
||||||
|
import Data.Vector.Split (chunksOf)
|
||||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
( AttrValue
|
( AttrValue
|
||||||
, AttrValue'ListValue
|
, AttrValue'ListValue
|
||||||
|
@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
||||||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||||
|
|
||||||
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
||||||
|
import qualified TensorFlow.Internal.Raw as Raw
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
type ResourceHandle = ResourceHandleProto
|
type ResourceHandle = ResourceHandleProto
|
||||||
|
@ -317,52 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
|
||||||
encodeTensorData = error "TODO (Complex Double)"
|
encodeTensorData = error "TODO (Complex Double)"
|
||||||
|
|
||||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
|
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
|
||||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
-- Strings can be encoded in various ways, see [0] for an overview.
|
||||||
-- table offsets for each element :: [Word64]
|
--
|
||||||
-- at each element offset:
|
-- The data starts with an array of TF_TString structs (24 bytes each), one
|
||||||
-- string length :: VarInt64
|
-- for each element in the tensor. In some cases, the actual string
|
||||||
-- string data :: [Word8]
|
-- contents are inlined in the TF_TString, in some cases they are in the
|
||||||
-- TODO: According to the v2.4.0 release notes, the byte layout for string
|
-- heap, in some cases they are appended to the end of the data.
|
||||||
-- tensors has been changed to a contiguous array of TF_TStrings.
|
--
|
||||||
|
-- When decoding, we delegate most of those details to the TString C API.
|
||||||
|
-- However, when encoding, the TString C API is prone to memory leaks given
|
||||||
|
-- the current design of tensorflow-haskell, so, instead we manually encode
|
||||||
|
-- all the strings in the "offset" format, where none of the string data is
|
||||||
|
-- stored in separate heap objects and so no destructor hook is necessary.
|
||||||
|
--
|
||||||
|
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
|
||||||
decodeTensorData tensorData =
|
decodeTensorData tensorData =
|
||||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
if S.length bytes < minBytes
|
||||||
if expected /= count
|
then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
|
||||||
then Left $ "decodeTensorData for ByteString count mismatch " ++
|
show (S.length bytes) ++ ", need at least " ++ show minBytes
|
||||||
show (expected, count)
|
-- Split the leading table into one TF_TString-sized chunk per element and
-- decode each; the chunk size must be the same constant used for minBytes.
else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf Raw.sizeOfTString bytes))
|
||||||
else V.mapM decodeString (S.convert offsets)
|
|
||||||
where
|
where
|
||||||
expected = S.length offsets
|
|
||||||
count = fromIntegral $ product $ FFI.tensorDataDimensions
|
|
||||||
$ unTensorData tensorData
|
|
||||||
bytes = FFI.tensorDataBytes $ unTensorData tensorData
|
bytes = FFI.tensorDataBytes $ unTensorData tensorData
|
||||||
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
|
numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
|
||||||
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
|
minBytes = Raw.sizeOfTString * numElements
|
||||||
decodeString :: Word64 -> Either String ByteString
|
|
||||||
decodeString offset =
|
|
||||||
let stringDataStart = B.drop (fromIntegral offset) dataBytes
|
|
||||||
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
|
|
||||||
stringParser :: Atto.Parser ByteString
|
|
||||||
stringParser = getVarInt >>= Atto.take . fromIntegral
|
|
||||||
encodeTensorData (Shape xs) vec =
|
encodeTensorData (Shape xs) vec =
|
||||||
TensorData $ FFI.TensorData xs dt byteVector
|
TensorData $ FFI.TensorData xs dt byteVector
|
||||||
where
|
where
|
||||||
dt = tensorType (undefined :: ByteString)
|
dt = tensorType (undefined :: ByteString)
|
||||||
|
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
|
||||||
-- Add a string to an offset table and data blob.
|
-- Add a string to an offset table and data blob.
|
||||||
addString :: (Builder, Builder, Word64)
|
addString :: (Builder, Builder, Word32, Word32)
|
||||||
-> ByteString
|
-> ByteString
|
||||||
-> (Builder, Builder, Word64)
|
-> (Builder, Builder, Word32, Word32)
|
||||||
addString (table, strings, offset) str =
|
addString (table, strings, tableOffset, stringsOffset) str =
|
||||||
( table <> Builder.word64LE offset
|
( table <> Builder.word32LE sizeField
|
||||||
, strings <> lengthBytes <> Builder.byteString str
|
<> Builder.word32LE offsetField
|
||||||
, offset + lengthBytesLen + strLen
|
<> Builder.word32LE capacityField
|
||||||
|
<> Builder.word32LE 0
|
||||||
|
<> Builder.word32LE 0
|
||||||
|
<> Builder.word32LE 0
|
||||||
|
, strings <> Builder.byteString str
|
||||||
|
, tableOffset + fromIntegral Raw.sizeOfTString
|
||||||
|
, stringsOffset + strLen
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
strLen = fromIntegral $ B.length str
|
strLen :: Word32 = fromIntegral $ B.length str
|
||||||
lengthBytes = putVarInt $ fromIntegral $ B.length str
|
-- TF_TString.size stores the union type tag in its two lowest bits; the
-- string length occupies the remaining (upper) bits.
|
||||||
lengthBytesLen =
|
sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
|
||||||
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
|
-- offset is relative to the start of the TF_TString instance, so
|
||||||
|
-- we add the remaining distance to the end of the table to the
|
||||||
|
-- offset from the start of the string data.
|
||||||
|
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
|
||||||
|
capacityField :: Word32 = strLen
|
||||||
-- Encode all strings.
|
-- Encode all strings.
|
||||||
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
|
(table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
|
||||||
-- Concat offset table with data.
|
-- Concat offset table with data.
|
||||||
bytes = table' <> strings'
|
bytes = table' <> strings'
|
||||||
-- Convert to Vector Word8.
|
-- Convert to Vector Word8.
|
||||||
|
|
|
@ -56,6 +56,7 @@ library
|
||||||
, temporary
|
, temporary
|
||||||
, transformers
|
, transformers
|
||||||
, vector
|
, vector
|
||||||
|
, vector-split
|
||||||
extra-libraries: tensorflow
|
extra-libraries: tensorflow
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
include-dirs: .
|
include-dirs: .
|
||||||
|
|
Loading…
Reference in a new issue