Switch to new TF_STRING format for TF 2.10

See https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
This commit is contained in:
fkm3 2023-01-21 19:40:07 -08:00
parent c5bfff9b4c
commit 63753cc20d
4 changed files with 76 additions and 34 deletions

View File

@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
, setSessionConfig , setSessionConfig
, setSessionTarget , setSessionTarget
, getAllOpList , getAllOpList
, unsafeTStringToByteString
-- * Internal helper. -- * Internal helper.
, useProtoAsVoidPtrLen , useProtoAsVoidPtrLen
) )
where where
import Control.Exception (assert)
import Control.Concurrent.Async (Async, async, cancel, waitCatch) import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when) import Control.Monad (when)
@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw import qualified TensorFlow.Internal.Raw as Raw
-- | Treat the given bytes as one @TF_TString@ struct and copy the string it
-- refers to into a fresh 'B.ByteString'.
--
-- The vector is expected to be exactly 'Raw.sizeOfTString' bytes long. This is
-- only enforced through 'assert', so the check disappears whenever assertions
-- are compiled out.
--
-- 'unsafePerformIO' is used because the C accessors live in 'IO'; the result
-- is copied out with 'B.packCStringLen' before the vector's pointer is
-- released.
--
-- NOTE(review): the struct is read in place through the vector's own storage
-- pointer. Presumably this matters for offset-encoded strings, whose data is
-- located relative to the struct's address -- confirm before passing a copy
-- of the struct bytes rather than a slice of the original tensor buffer.
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
unsafeTStringToByteString rawStruct =
    assert (S.length rawStruct == Raw.sizeOfTString) (unsafePerformIO copyOut)
  where
    -- Ask the TString C API where the bytes live and how many there are,
    -- then copy them into Haskell-managed memory.
    copyOut = S.unsafeWith rawStruct $ \structPtr -> do
        let tstr = Raw.TString (castPtr structPtr)
        dataPtr <- Raw.stringGetDataPointer tstr
        byteCount <- Raw.stringGetSize tstr
        B.packCStringLen (dataPtr, fromIntegral byteCount)
data TensorFlowException = TensorFlowException Raw.Code T.Text data TensorFlowException = TensorFlowException Raw.Code T.Text
deriving (Show, Eq, Typeable) deriving (Show, Eq, Typeable)

View File

@ -44,6 +44,23 @@ message :: Status -> IO CString
message = {# call TF_Message as ^ #} message = {# call TF_Message as ^ #}
-- TString.
-- c2hs pointer hook: introduces the Haskell newtype 'TString', which wraps a
-- pointer to the C struct @TF_TString@.
{# pointer *TF_TString as TString newtype #}
-- | Size in bytes of a single @TF_TString@ struct.
--
-- NOTE(review): this is a hard-coded constant; presumably it has to match
-- @sizeof(TF_TString)@ in the TensorFlow C headers this package is built
-- against -- confirm when upgrading TensorFlow or targeting a new platform.
sizeOfTString :: Int
sizeOfTString = 24
-- TF_TString_Type::TF_TSTR_OFFSET
-- | Numeric value of the "offset" representation tag of a @TF_TString@.
--
-- NOTE(review): the value 2 is copied by hand rather than derived from the C
-- enum; verify it against the TensorFlow headers in use.
tstringOffsetTypeTag :: Word32
tstringOffsetTypeTag = 2
-- | Binding to the C function @TF_StringGetDataPointer@: given a 'TString',
-- yields a pointer to its character data.
--
-- NOTE(review): presumably the data is not guaranteed to be NUL-terminated,
-- so pair this with 'stringGetSize' rather than treating the result as a
-- C string -- confirm against the TensorFlow C API.
stringGetDataPointer :: TString -> IO CString
stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #}
-- | Binding to the C function @TF_StringGetSize@: given a 'TString', yields
-- the size of its string data.
--
-- NOTE(review): the result is typed as 'CULong'; presumably the C function
-- returns @size_t@, which only coincides with @unsigned long@ on some
-- platforms -- confirm for non-LP64 targets.
stringGetSize :: TString -> IO CULong
stringGetSize = {# call TF_StringGetSize as ^ #}
-- Buffer. -- Buffer.
data Buffer data Buffer
{# pointer *TF_Buffer as BufferPtr -> Buffer #} {# pointer *TF_Buffer as BufferPtr -> Buffer #}

View File

@ -64,6 +64,7 @@ module TensorFlow.Types
, AllTensorTypes , AllTensorTypes
) where ) where
import Data.Bits (shiftL, (.|.))
import Data.ProtoLens.Message(defMessage) import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..)) import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex) import Data.Complex (Complex)
@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified Data.Vector.Storable as S import qualified Data.Vector.Storable as S
import Data.Vector.Split (chunksOf)
import Proto.Tensorflow.Core.Framework.AttrValue import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue ( AttrValue
, AttrValue'ListValue , AttrValue'ListValue
@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt) import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.Raw as Raw
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
type ResourceHandle = ResourceHandleProto type ResourceHandle = ResourceHandleProto
@ -317,52 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
encodeTensorData = error "TODO (Complex Double)" encodeTensorData = error "TODO (Complex Double)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h): -- Strings can be encoded in various ways, see [0] for an overview.
-- table offsets for each element :: [Word64] --
-- at each element offset: -- The data starts with an array of TF_TString structs (24 bytes each), one
-- string length :: VarInt64 -- for each element in the tensor. In some cases, the actual string
-- string data :: [Word8] -- contents are inlined in the TF_TString, in some cases they are in the
-- TODO: According to the v2.4.0 release notes, the byte layout for string -- heap, in some cases they are appended to the end of the data.
-- tensors has been changed to a contiguous array of TF_TStrings. --
-- When decoding, we delegate most of those details to the TString C API.
-- However, when encoding, the TString C API is prone to memory leaks given
-- the current design of tensorflow-haskell, so, instead we manually encode
-- all the strings in the "offset" format, where none of the string data is
-- stored in separate heap objects and so no destructor hook is necessary.
--
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
decodeTensorData tensorData = decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ if S.length bytes < minBytes
if expected /= count then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
then Left $ "decodeTensorData for ByteString count mismatch " ++ show (S.length bytes) ++ ", need at least " ++ show minBytes
show (expected, count) else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
else V.mapM decodeString (S.convert offsets)
where where
expected = S.length offsets
count = fromIntegral $ product $ FFI.tensorDataDimensions
$ unTensorData tensorData
bytes = FFI.tensorDataBytes $ unTensorData tensorData bytes = FFI.tensorDataBytes $ unTensorData tensorData
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64 numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes minBytes = Raw.sizeOfTString * numElements
decodeString :: Word64 -> Either String ByteString
decodeString offset =
let stringDataStart = B.drop (fromIntegral offset) dataBytes
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
stringParser :: Atto.Parser ByteString
stringParser = getVarInt >>= Atto.take . fromIntegral
encodeTensorData (Shape xs) vec = encodeTensorData (Shape xs) vec =
TensorData $ FFI.TensorData xs dt byteVector TensorData $ FFI.TensorData xs dt byteVector
where where
dt = tensorType (undefined :: ByteString) dt = tensorType (undefined :: ByteString)
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
-- Add a string to an offset table and data blob. -- Add a string to an offset table and data blob.
addString :: (Builder, Builder, Word64) addString :: (Builder, Builder, Word32, Word32)
-> ByteString -> ByteString
-> (Builder, Builder, Word64) -> (Builder, Builder, Word32, Word32)
addString (table, strings, offset) str = addString (table, strings, tableOffset, stringsOffset) str =
( table <> Builder.word64LE offset ( table <> Builder.word32LE sizeField
, strings <> lengthBytes <> Builder.byteString str <> Builder.word32LE offsetField
, offset + lengthBytesLen + strLen <> Builder.word32LE capacityField
<> Builder.word32LE 0
<> Builder.word32LE 0
<> Builder.word32LE 0
, strings <> Builder.byteString str
, tableOffset + fromIntegral Raw.sizeOfTString
, stringsOffset + strLen
) )
where where
strLen = fromIntegral $ B.length str strLen :: Word32 = fromIntegral $ B.length str
lengthBytes = putVarInt $ fromIntegral $ B.length str -- TF_TString.size includes a union tag in the first two bits.
lengthBytesLen = sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes -- offset is relative to the start of the TF_TString instance, so
-- we add the remaining distance to the end of the table to the
-- offset from the start of the string data.
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
capacityField :: Word32 = strLen
-- Encode all strings. -- Encode all strings.
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec (table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
-- Concat offset table with data. -- Concat offset table with data.
bytes = table' <> strings' bytes = table' <> strings'
-- Convert to Vector Word8. -- Convert to Vector Word8.

View File

@ -56,6 +56,7 @@ library
, temporary , temporary
, transformers , transformers
, vector , vector
, vector-split
extra-libraries: tensorflow extra-libraries: tensorflow
default-language: Haskell2010 default-language: Haskell2010
include-dirs: . include-dirs: .