1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 03:19:44 +01:00

Switch to new TF_STRING format for TF 2.10

See https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
This commit is contained in:
fkm3 2023-01-21 19:40:07 -08:00
parent b21dee98f5
commit 199f1c7663
4 changed files with 76 additions and 34 deletions

View file

@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
, setSessionConfig , setSessionConfig
, setSessionTarget , setSessionTarget
, getAllOpList , getAllOpList
, unsafeTStringToByteString
-- * Internal helper. -- * Internal helper.
, useProtoAsVoidPtrLen , useProtoAsVoidPtrLen
) )
where where
import Control.Exception (assert)
import Control.Concurrent.Async (Async, async, cancel, waitCatch) import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when) import Control.Monad (when)
@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw import qualified TensorFlow.Internal.Raw as Raw
-- Interpret a vector of bytes as a TF_TString struct and copy the pointed
-- to string into a ByteString.
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
unsafeTStringToByteString v =
assert (S.length v == Raw.sizeOfTString) $
unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
let tstring = Raw.TString (castPtr tstringPtr)
p <- Raw.stringGetDataPointer tstring
n <- Raw.stringGetSize tstring
B.packCStringLen (p, fromIntegral n)
data TensorFlowException = TensorFlowException Raw.Code T.Text data TensorFlowException = TensorFlowException Raw.Code T.Text
deriving (Show, Eq, Typeable) deriving (Show, Eq, Typeable)

View file

@ -44,6 +44,23 @@ message :: Status -> IO CString
message = {# call TF_Message as ^ #} message = {# call TF_Message as ^ #}
-- TString.
{# pointer *TF_TString as TString newtype #}
sizeOfTString :: Int
sizeOfTString = 24
-- TF_TString_Type::TF_TSTR_OFFSET
tstringOffsetTypeTag :: Word32
tstringOffsetTypeTag = 2
stringGetDataPointer :: TString -> IO CString
stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #}
stringGetSize :: TString -> IO CULong
stringGetSize = {# call TF_StringGetSize as ^ #}
-- Buffer. -- Buffer.
data Buffer data Buffer
{# pointer *TF_Buffer as BufferPtr -> Buffer #} {# pointer *TF_Buffer as BufferPtr -> Buffer #}

View file

@ -64,6 +64,7 @@ module TensorFlow.Types
, AllTensorTypes , AllTensorTypes
) where ) where
import Data.Bits (shiftL, (.|.))
import Data.ProtoLens.Message(defMessage) import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..)) import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex) import Data.Complex (Complex)
@ -78,7 +79,6 @@ import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~), (^..), under) import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
import Lens.Family2.Unchecked (adapter) import Lens.Family2.Unchecked (adapter)
import Text.Printf (printf) import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder) import Data.ByteString.Builder (Builder)
@ -86,6 +86,7 @@ import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified Data.Vector.Storable as S import qualified Data.Vector.Storable as S
import Data.Vector.Split (chunksOf)
import Proto.Tensorflow.Core.Framework.AttrValue import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue ( AttrValue
, AttrValue'ListValue , AttrValue'ListValue
@ -126,7 +127,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
) )
import Proto.Tensorflow.Core.Framework.Types (DataType(..)) import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt) import qualified TensorFlow.Internal.Raw as Raw
import qualified TensorFlow.Internal.FFI as FFI import qualified TensorFlow.Internal.FFI as FFI
type ResourceHandle = ResourceHandleProto type ResourceHandle = ResourceHandleProto
@ -317,50 +318,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
encodeTensorData = error "TODO (Complex Double)" encodeTensorData = error "TODO (Complex Double)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h): -- Strings can be encoded in various ways, see [0] for an overview.
-- table offsets for each element :: [Word64] --
-- at each element offset: -- The data starts with an array of TF_TString structs (24 bytes each), one
-- string length :: VarInt64 -- for each element in the tensor. In some cases, the actual string
-- string data :: [Word8] -- contents are inlined in the TF_TString, in some cases they are in the
-- heap, in some cases they are appended to the end of the data.
--
-- When decoding, we delegate most of those details to the TString C API.
-- However, when encoding, the TString C API is prone to memory leaks given
-- the current design of tensorflow-haskell, so, instead we manually encode
-- all the strings in the "offset" format, where none of the string data is
-- stored in separate heap objects and so no destructor hook is necessary.
--
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
decodeTensorData tensorData = decodeTensorData tensorData =
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $ if S.length bytes < minBytes
if expected /= count then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
then Left $ "decodeTensorData for ByteString count mismatch " ++ show (S.length bytes) ++ ", need at least " ++ show minBytes
show (expected, count) else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
else V.mapM decodeString (S.convert offsets)
where where
expected = S.length offsets
count = fromIntegral $ product $ FFI.tensorDataDimensions
$ unTensorData tensorData
bytes = FFI.tensorDataBytes $ unTensorData tensorData bytes = FFI.tensorDataBytes $ unTensorData tensorData
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64 numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes minBytes = Raw.sizeOfTString * numElements
decodeString :: Word64 -> Either String ByteString
decodeString offset =
let stringDataStart = B.drop (fromIntegral offset) dataBytes
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
stringParser :: Atto.Parser ByteString
stringParser = getVarInt >>= Atto.take . fromIntegral
encodeTensorData (Shape xs) vec = encodeTensorData (Shape xs) vec =
TensorData $ FFI.TensorData xs dt byteVector TensorData $ FFI.TensorData xs dt byteVector
where where
dt = tensorType (undefined :: ByteString) dt = tensorType (undefined :: ByteString)
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
-- Add a string to an offset table and data blob. -- Add a string to an offset table and data blob.
addString :: (Builder, Builder, Word64) addString :: (Builder, Builder, Word32, Word32)
-> ByteString -> ByteString
-> (Builder, Builder, Word64) -> (Builder, Builder, Word32, Word32)
addString (table, strings, offset) str = addString (table, strings, tableOffset, stringsOffset) str =
( table <> Builder.word64LE offset ( table <> Builder.word32LE sizeField
, strings <> lengthBytes <> Builder.byteString str <> Builder.word32LE offsetField
, offset + lengthBytesLen + strLen <> Builder.word32LE capacityField
<> Builder.word32LE 0
<> Builder.word32LE 0
<> Builder.word32LE 0
, strings <> Builder.byteString str
, tableOffset + fromIntegral Raw.sizeOfTString
, stringsOffset + strLen
) )
where where
strLen = fromIntegral $ B.length str strLen :: Word32 = fromIntegral $ B.length str
lengthBytes = putVarInt $ fromIntegral $ B.length str -- TF_TString.size includes a union tag in the first two bits.
lengthBytesLen = sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes -- offset is relative to the start of the TF_TString instance, so
-- we add the remaining distance to the end of the table to the
-- offset from the start of the string data.
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
capacityField :: Word32 = strLen
-- Encode all strings. -- Encode all strings.
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec (table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
-- Concat offset table with data. -- Concat offset table with data.
bytes = table' <> strings' bytes = table' <> strings'
-- Convert to Vector Word8. -- Convert to Vector Word8.

View file

@ -56,6 +56,7 @@ library
, temporary , temporary
, transformers , transformers
, vector , vector
, vector-split
extra-libraries: tensorflow extra-libraries: tensorflow
default-language: Haskell2010 default-language: Haskell2010
include-dirs: . include-dirs: .