Switch to new TF_STRING format for TF 2.10

See https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
This commit is contained in:
fkm3 2023-01-21 19:40:07 -08:00
parent c5bfff9b4c
commit 63753cc20d
4 changed files with 76 additions and 34 deletions

View File

@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
, setSessionConfig
, setSessionTarget
, getAllOpList
, unsafeTStringToByteString
-- * Internal helper.
, useProtoAsVoidPtrLen
)
where
import Control.Exception (assert)
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when)
@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- Interpret a vector of bytes as a TF_TString struct and copy the pointed
-- to string into a ByteString.
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
unsafeTStringToByteString v =
assert (S.length v == Raw.sizeOfTString) $
unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
let tstring = Raw.TString (castPtr tstringPtr)
p <- Raw.stringGetDataPointer tstring
n <- Raw.stringGetSize tstring
B.packCStringLen (p, fromIntegral n)
data TensorFlowException = TensorFlowException Raw.Code T.Text
deriving (Show, Eq, Typeable)

View File

@ -44,6 +44,23 @@ message :: Status -> IO CString
message = {# call TF_Message as ^ #}
-- TString.
{# pointer *TF_TString as TString newtype #}
sizeOfTString :: Int
sizeOfTString = 24
-- TF_TString_Type::TF_TSTR_OFFSET
tstringOffsetTypeTag :: Word32
tstringOffsetTypeTag = 2
stringGetDataPointer :: TString -> IO CString
stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #}
stringGetSize :: TString -> IO CULong
stringGetSize = {# call TF_StringGetSize as ^ #}
-- Buffer.
data Buffer
{# pointer *TF_Buffer as BufferPtr -> Buffer #}

View File

@ -64,6 +64,7 @@ module TensorFlow.Types
, AllTensorTypes
) where
import Data.Bits (shiftL, (.|.))
import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Data.Vector.Split (chunksOf)
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue
, AttrValue'ListValue
@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.Raw as Raw
import qualified TensorFlow.Internal.FFI as FFI
type ResourceHandle = ResourceHandleProto
@ -317,52 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
encodeTensorData = error "TODO (Complex Double)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64]
-- at each element offset:
-- string length :: VarInt64
-- string data :: [Word8]
-- TODO: According to the v2.4.0 release notes, the byte layout for string
-- tensors has been changed to a contiguous array of TF_TStrings.
-- Strings can be encoded in various ways, see [0] for an overview.
--
-- The data starts with an array of TF_TString structs (24 bytes each), one
-- for each element in the tensor. In some cases, the actual string
-- contents are inlined in the TF_TString, in some cases they are in the
-- heap, in some cases they are appended to the end of the data.
--
-- When decoding, we delegate most of those details to the TString C API.
-- However, when encoding, the TString C API is prone to memory leaks given
-- the current design of tensorflow-haskell, so, instead we manually encode
-- all the strings in the "offset" format, where none of the string data is
-- stored in separate heap objects and so no destructor hook is necessary.
--
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
if expected /= count
then Left $ "decodeTensorData for ByteString count mismatch " ++
show (expected, count)
else V.mapM decodeString (S.convert offsets)
if S.length bytes < minBytes
then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
show (S.length bytes) ++ ", need at least " ++ show minBytes
else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
where
expected = S.length offsets
count = fromIntegral $ product $ FFI.tensorDataDimensions
$ unTensorData tensorData
bytes = FFI.tensorDataBytes $ unTensorData tensorData
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
decodeString :: Word64 -> Either String ByteString
decodeString offset =
let stringDataStart = B.drop (fromIntegral offset) dataBytes
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
stringParser :: Atto.Parser ByteString
stringParser = getVarInt >>= Atto.take . fromIntegral
numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
minBytes = Raw.sizeOfTString * numElements
encodeTensorData (Shape xs) vec =
TensorData $ FFI.TensorData xs dt byteVector
where
dt = tensorType (undefined :: ByteString)
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
-- Add a string to an offset table and data blob.
addString :: (Builder, Builder, Word64)
addString :: (Builder, Builder, Word32, Word32)
-> ByteString
-> (Builder, Builder, Word64)
addString (table, strings, offset) str =
( table <> Builder.word64LE offset
, strings <> lengthBytes <> Builder.byteString str
, offset + lengthBytesLen + strLen
-> (Builder, Builder, Word32, Word32)
addString (table, strings, tableOffset, stringsOffset) str =
( table <> Builder.word32LE sizeField
<> Builder.word32LE offsetField
<> Builder.word32LE capacityField
<> Builder.word32LE 0
<> Builder.word32LE 0
<> Builder.word32LE 0
, strings <> Builder.byteString str
, tableOffset + fromIntegral Raw.sizeOfTString
, stringsOffset + strLen
)
where
strLen = fromIntegral $ B.length str
lengthBytes = putVarInt $ fromIntegral $ B.length str
lengthBytesLen =
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
strLen :: Word32 = fromIntegral $ B.length str
-- TF_TString.size includes a union tag in the first two bits.
sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
-- offset is relative to the start of the TF_TString instance, so
-- we add the remaining distance to the end of the table to the
-- offset from the start of the string data.
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
capacityField :: Word32 = strLen
-- Encode all strings.
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
(table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
-- Concat offset table with data.
bytes = table' <> strings'
-- Convert to Vector Word8.

View File

@ -56,6 +56,7 @@ library
, temporary
, transformers
, vector
, vector-split
extra-libraries: tensorflow
default-language: Haskell2010
include-dirs: .