mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 19:39:49 +01:00
Switch to new TF_STRING format for TF 2.10
See https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
This commit is contained in:
parent
c5bfff9b4c
commit
63753cc20d
4 changed files with 76 additions and 34 deletions
|
@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
|
|||
, setSessionConfig
|
||||
, setSessionTarget
|
||||
, getAllOpList
|
||||
, unsafeTStringToByteString
|
||||
-- * Internal helper.
|
||||
, useProtoAsVoidPtrLen
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Exception (assert)
|
||||
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
||||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||
import Control.Monad (when)
|
||||
|
@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
|||
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
|
||||
-- Interpret a vector of bytes as a TF_TString struct and copy the pointed
|
||||
-- to string into a ByteString.
|
||||
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
|
||||
unsafeTStringToByteString v =
|
||||
assert (S.length v == Raw.sizeOfTString) $
|
||||
unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
|
||||
let tstring = Raw.TString (castPtr tstringPtr)
|
||||
p <- Raw.stringGetDataPointer tstring
|
||||
n <- Raw.stringGetSize tstring
|
||||
B.packCStringLen (p, fromIntegral n)
|
||||
|
||||
data TensorFlowException = TensorFlowException Raw.Code T.Text
|
||||
deriving (Show, Eq, Typeable)
|
||||
|
||||
|
|
|
@ -44,6 +44,23 @@ message :: Status -> IO CString
|
|||
message = {# call TF_Message as ^ #}
|
||||
|
||||
|
||||
-- TString.
|
||||
{# pointer *TF_TString as TString newtype #}
|
||||
|
||||
sizeOfTString :: Int
|
||||
sizeOfTString = 24
|
||||
|
||||
-- TF_TString_Type::TF_TSTR_OFFSET
|
||||
tstringOffsetTypeTag :: Word32
|
||||
tstringOffsetTypeTag = 2
|
||||
|
||||
stringGetDataPointer :: TString -> IO CString
|
||||
stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #}
|
||||
|
||||
stringGetSize :: TString -> IO CULong
|
||||
stringGetSize = {# call TF_StringGetSize as ^ #}
|
||||
|
||||
|
||||
-- Buffer.
|
||||
data Buffer
|
||||
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
|
||||
|
|
|
@ -64,6 +64,7 @@ module TensorFlow.Types
|
|||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.Bits (shiftL, (.|.))
|
||||
import Data.ProtoLens.Message(defMessage)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Data.Complex (Complex)
|
||||
|
@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder
|
|||
import qualified Data.ByteString.Lazy as L
|
||||
import qualified Data.Vector as V
|
||||
import qualified Data.Vector.Storable as S
|
||||
import Data.Vector.Split (chunksOf)
|
||||
import Proto.Tensorflow.Core.Framework.AttrValue
|
||||
( AttrValue
|
||||
, AttrValue'ListValue
|
||||
|
@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
|||
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
||||
|
||||
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
|
||||
import qualified TensorFlow.Internal.Raw as Raw
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
type ResourceHandle = ResourceHandleProto
|
||||
|
@ -317,52 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
|
|||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
|
||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||
-- table offsets for each element :: [Word64]
|
||||
-- at each element offset:
|
||||
-- string length :: VarInt64
|
||||
-- string data :: [Word8]
|
||||
-- TODO: According to the v2.4.0 release notes, the byte layout for string
|
||||
-- tensors has been changed to a contiguous array of TF_TStrings.
|
||||
-- Strings can be encoded in various ways, see [0] for an overview.
|
||||
--
|
||||
-- The data starts with an array of TF_TString structs (24 bytes each), one
|
||||
-- for each element in the tensor. In some cases, the actual string
|
||||
-- contents are inlined in the TF_TString, in some cases they are in the
|
||||
-- heap, in some cases they are appended to the end of the data.
|
||||
--
|
||||
-- When decoding, we delegate most of those details to the TString C API.
|
||||
-- However, when encoding, the TString C API is prone to memory leaks given
|
||||
-- the current design of tensorflow-haskell, so, instead we manually encode
|
||||
-- all the strings in the "offset" format, where none of the string data is
|
||||
-- stored in separate heap objects and so no destructor hook is necessary.
|
||||
--
|
||||
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
|
||||
decodeTensorData tensorData =
|
||||
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
|
||||
if expected /= count
|
||||
then Left $ "decodeTensorData for ByteString count mismatch " ++
|
||||
show (expected, count)
|
||||
else V.mapM decodeString (S.convert offsets)
|
||||
if S.length bytes < minBytes
|
||||
then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
|
||||
show (S.length bytes) ++ ", need at least " ++ show minBytes
|
||||
else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
|
||||
where
|
||||
expected = S.length offsets
|
||||
count = fromIntegral $ product $ FFI.tensorDataDimensions
|
||||
$ unTensorData tensorData
|
||||
bytes = FFI.tensorDataBytes $ unTensorData tensorData
|
||||
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
|
||||
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
|
||||
decodeString :: Word64 -> Either String ByteString
|
||||
decodeString offset =
|
||||
let stringDataStart = B.drop (fromIntegral offset) dataBytes
|
||||
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
|
||||
stringParser :: Atto.Parser ByteString
|
||||
stringParser = getVarInt >>= Atto.take . fromIntegral
|
||||
numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
|
||||
minBytes = Raw.sizeOfTString * numElements
|
||||
encodeTensorData (Shape xs) vec =
|
||||
TensorData $ FFI.TensorData xs dt byteVector
|
||||
where
|
||||
dt = tensorType (undefined :: ByteString)
|
||||
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
|
||||
-- Add a string to an offset table and data blob.
|
||||
addString :: (Builder, Builder, Word64)
|
||||
addString :: (Builder, Builder, Word32, Word32)
|
||||
-> ByteString
|
||||
-> (Builder, Builder, Word64)
|
||||
addString (table, strings, offset) str =
|
||||
( table <> Builder.word64LE offset
|
||||
, strings <> lengthBytes <> Builder.byteString str
|
||||
, offset + lengthBytesLen + strLen
|
||||
-> (Builder, Builder, Word32, Word32)
|
||||
addString (table, strings, tableOffset, stringsOffset) str =
|
||||
( table <> Builder.word32LE sizeField
|
||||
<> Builder.word32LE offsetField
|
||||
<> Builder.word32LE capacityField
|
||||
<> Builder.word32LE 0
|
||||
<> Builder.word32LE 0
|
||||
<> Builder.word32LE 0
|
||||
, strings <> Builder.byteString str
|
||||
, tableOffset + fromIntegral Raw.sizeOfTString
|
||||
, stringsOffset + strLen
|
||||
)
|
||||
where
|
||||
strLen = fromIntegral $ B.length str
|
||||
lengthBytes = putVarInt $ fromIntegral $ B.length str
|
||||
lengthBytesLen =
|
||||
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
|
||||
strLen :: Word32 = fromIntegral $ B.length str
|
||||
-- TF_TString.size includes a union tag in the first two bits.
|
||||
sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
|
||||
-- offset is relative to the start of the TF_TString instance, so
|
||||
-- we add the remaining distance to the end of the table to the
|
||||
-- offset from the start of the string data.
|
||||
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
|
||||
capacityField :: Word32 = strLen
|
||||
-- Encode all strings.
|
||||
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
|
||||
(table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
|
||||
-- Concat offset table with data.
|
||||
bytes = table' <> strings'
|
||||
-- Convert to Vector Word8.
|
||||
|
|
|
@ -56,6 +56,7 @@ library
|
|||
, temporary
|
||||
, transformers
|
||||
, vector
|
||||
, vector-split
|
||||
extra-libraries: tensorflow
|
||||
default-language: Haskell2010
|
||||
include-dirs: .
|
||||
|
|
Loading…
Reference in a new issue