From 63753cc20da7dbe8668106496f783923ccd6c370 Mon Sep 17 00:00:00 2001 From: fkm3 Date: Sat, 21 Jan 2023 19:40:07 -0800 Subject: [PATCH] Switch to new TF_STRING format for TF 2.10 See https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md --- tensorflow/src/TensorFlow/Internal/FFI.hs | 13 ++++ tensorflow/src/TensorFlow/Internal/Raw.chs | 17 +++++ tensorflow/src/TensorFlow/Types.hs | 79 ++++++++++++---------- tensorflow/tensorflow.cabal | 1 + 4 files changed, 76 insertions(+), 34 deletions(-) diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs index 2e29f68..3a840e8 100644 --- a/tensorflow/src/TensorFlow/Internal/FFI.hs +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI , setSessionConfig , setSessionTarget , getAllOpList + , unsafeTStringToByteString -- * Internal helper. , useProtoAsVoidPtrLen ) where +import Control.Exception (assert) import Control.Concurrent.Async (Async, async, cancel, waitCatch) import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) import Control.Monad (when) @@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto) import qualified TensorFlow.Internal.Raw as Raw +-- Interpret a vector of bytes as a TF_TString struct and copy the pointed +-- to string into a ByteString. +unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString +unsafeTStringToByteString v = + assert (S.length v == Raw.sizeOfTString) $ + unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do + let tstring = Raw.TString (castPtr tstringPtr) + p <- Raw.stringGetDataPointer tstring + n <- Raw.stringGetSize tstring + B.packCStringLen (p, fromIntegral n) + data TensorFlowException = TensorFlowException Raw.Code T.Text deriving (Show, Eq, Typeable) diff --git a/tensorflow/src/TensorFlow/Internal/Raw.chs b/tensorflow/src/TensorFlow/Internal/Raw.chs index 8e2c430..f1fa7a0 100644 --- a/tensorflow/src/TensorFlow/Internal/Raw.chs +++ b/tensorflow/src/TensorFlow/Internal/Raw.chs @@ -44,6 +44,23 @@ message :: Status -> IO CString message = {# call TF_Message as ^ #} +-- TString. +{# pointer *TF_TString as TString newtype #} + +sizeOfTString :: Int +sizeOfTString = 24 + +-- TF_TString_Type::TF_TSTR_OFFSET +tstringOffsetTypeTag :: Word32 +tstringOffsetTypeTag = 2 + +stringGetDataPointer :: TString -> IO CString +stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #} + +stringGetSize :: TString -> IO CULong +stringGetSize = {# call TF_StringGetSize as ^ #} + + -- Buffer. data Buffer {# pointer *TF_Buffer as BufferPtr -> Buffer #} diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index 714fbfa..c403cde 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -64,6 +64,7 @@ module TensorFlow.Types , AllTensorTypes ) where +import Data.Bits (shiftL, (.|.)) import Data.ProtoLens.Message(defMessage) import Data.Functor.Identity (Identity(..)) import Data.Complex (Complex) @@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder import qualified Data.ByteString.Lazy as L import qualified Data.Vector as V import qualified Data.Vector.Storable as S +import Data.Vector.Split (chunksOf) import Proto.Tensorflow.Core.Framework.AttrValue ( AttrValue , AttrValue'ListValue @@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import TensorFlow.Internal.VarInt (getVarInt, putVarInt) +import qualified TensorFlow.Internal.Raw as Raw import qualified TensorFlow.Internal.FFI as FFI type ResourceHandle = ResourceHandleProto @@ -317,52 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where encodeTensorData = error "TODO (Complex Double)" instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where - -- Encoded data layout (described in third_party/tensorflow/c/c_api.h): - -- table offsets for each element :: [Word64] - -- at each element offset: - -- string length :: VarInt64 - -- string data :: [Word8] - -- TODO: According to the v2.4.0 release notes, the byte layout for string - -- tensors has been changed to a contiguous array of TF_TStrings. + -- Strings can be encoded in various ways, see [0] for an overview. + -- + -- The data starts with an array of TF_TString structs (24 bytes each), one + -- for each element in the tensor. In some cases, the actual string + -- contents are inlined in the TF_TString, in some cases they are in the + -- heap, in some cases they are appended to the end of the data. + -- + -- When decoding, we delegate most of those details to the TString C API. + -- However, when encoding, the TString C API is prone to memory leaks given + -- the current design of tensorflow-haskell, so, instead we manually encode + -- all the strings in the "offset" format, where none of the string data is + -- stored in separate heap objects and so no destructor hook is necessary. + -- + -- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md decodeTensorData tensorData = - either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ - if expected /= count - then Left $ "decodeTensorData for ByteString count mismatch " ++ - show (expected, count) - else V.mapM decodeString (S.convert offsets) + if S.length bytes < minBytes + then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++ + show (S.length bytes) ++ ", need at least " ++ show minBytes + else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes)) where - expected = S.length offsets - count = fromIntegral $ product $ FFI.tensorDataDimensions - $ unTensorData tensorData bytes = FFI.tensorDataBytes $ unTensorData tensorData - offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64 - dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes - decodeString :: Word64 -> Either String ByteString - decodeString offset = - let stringDataStart = B.drop (fromIntegral offset) dataBytes - in Atto.eitherResult $ Atto.parse stringParser stringDataStart - stringParser :: Atto.Parser ByteString - stringParser = getVarInt >>= Atto.take . fromIntegral + numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData + minBytes = Raw.sizeOfTString * numElements encodeTensorData (Shape xs) vec = TensorData $ FFI.TensorData xs dt byteVector where dt = tensorType (undefined :: ByteString) + tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec) -- Add a string to an offset table and data blob. - addString :: (Builder, Builder, Word64) + addString :: (Builder, Builder, Word32, Word32) -> ByteString - -> (Builder, Builder, Word64) - addString (table, strings, offset) str = - ( table <> Builder.word64LE offset - , strings <> lengthBytes <> Builder.byteString str - , offset + lengthBytesLen + strLen + -> (Builder, Builder, Word32, Word32) + addString (table, strings, tableOffset, stringsOffset) str = + ( table <> Builder.word32LE sizeField + <> Builder.word32LE offsetField + <> Builder.word32LE capacityField + <> Builder.word32LE 0 + <> Builder.word32LE 0 + <> Builder.word32LE 0 + , strings <> Builder.byteString str + , tableOffset + fromIntegral Raw.sizeOfTString + , stringsOffset + strLen ) where - strLen = fromIntegral $ B.length str - lengthBytes = putVarInt $ fromIntegral $ B.length str - lengthBytesLen = - fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes + strLen :: Word32 = fromIntegral $ B.length str + -- TF_TString.size includes a union tag in the first two bits. + sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag + -- offset is relative to the start of the TF_TString instance, so + -- we add the remaining distance to the end of the table to the + -- offset from the start of the string data. + offsetField :: Word32 = tableSize - tableOffset + stringsOffset + capacityField :: Word32 = strLen -- Encode all strings. - (table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec + (table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec -- Concat offset table with data. bytes = table' <> strings' -- Convert to Vector Word8. diff --git a/tensorflow/tensorflow.cabal b/tensorflow/tensorflow.cabal index 40bbe6c..06d2adc 100644 --- a/tensorflow/tensorflow.cabal +++ b/tensorflow/tensorflow.cabal @@ -56,6 +56,7 @@ library , temporary , transformers , vector + , vector-split extra-libraries: tensorflow default-language: Haskell2010 include-dirs: .