mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
6b19e54722
* Update README to refer to 2.3.0-gpu. * Remove old package documentation from haddock directory.
985 lines
45 KiB
Text
985 lines
45 KiB
Text
-- Hoogle documentation, generated by Haddock
|
|
-- See Hoogle, http://www.haskell.org/hoogle/
|
|
|
|
|
|
-- | TensorFlow bindings.
|
|
--
|
|
-- This library provides an interface to the TensorFlow bindings.
|
|
-- <a>TensorFlow.Core</a> contains the base API for building and running
|
|
-- computational graphs. Other packages such as <tt>tensorflow-ops</tt>
|
|
-- contain bindings to the actual computational kernels.
|
|
--
|
|
-- For more documentation and examples, see
|
|
-- <a>https://github.com/tensorflow/haskell#readme</a>
|
|
@package tensorflow
|
|
@version 0.3.0.0
|
|
|
|
module TensorFlow.Internal.FFI
|
|
data TensorFlowException
|
|
TensorFlowException :: Code -> Text -> TensorFlowException
|
|
data Session
|
|
|
|
-- | Runs the given action after creating a session with options populated
|
|
-- by the given optionSetter.
|
|
withSession :: (MonadIO m, MonadMask m) => (SessionOptions -> IO ()) -> ((IO () -> IO ()) -> Session -> m a) -> m a
|
|
extendGraph :: Session -> GraphDef -> IO ()
|
|
run :: Session -> [(ByteString, TensorData)] -> [ByteString] -> [ByteString] -> IO [TensorData]
|
|
|
|
-- | All of the data needed to represent a tensor.
|
|
data TensorData
|
|
TensorData :: [Int64] -> !DataType -> !Vector Word8 -> TensorData
|
|
[tensorDataDimensions] :: TensorData -> [Int64]
|
|
[tensorDataType] :: TensorData -> !DataType
|
|
[tensorDataBytes] :: TensorData -> !Vector Word8
|
|
setSessionConfig :: ConfigProto -> SessionOptions -> IO ()
|
|
setSessionTarget :: ByteString -> SessionOptions -> IO ()
|
|
|
|
-- | Returns the serialized OpList of all OpDefs defined in this address
|
|
-- space.
|
|
getAllOpList :: IO ByteString
|
|
|
|
-- | Serializes the given msg and provides it as (ptr,len) argument to the
|
|
-- given action.
|
|
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) => msg -> (Ptr b -> c -> IO a) -> IO a
|
|
instance GHC.Classes.Eq TensorFlow.Internal.FFI.TensorData
|
|
instance GHC.Show.Show TensorFlow.Internal.FFI.TensorData
|
|
instance GHC.Classes.Eq TensorFlow.Internal.FFI.TensorFlowException
|
|
instance GHC.Show.Show TensorFlow.Internal.FFI.TensorFlowException
|
|
instance GHC.Exception.Type.Exception TensorFlow.Internal.FFI.TensorFlowException
|
|
|
|
|
|
-- | Originally taken from internal proto-lens code.
|
|
module TensorFlow.Internal.VarInt
|
|
|
|
-- | Decode an unsigned varint.
|
|
getVarInt :: Parser Word64
|
|
|
|
-- | Encode a Word64.
|
|
putVarInt :: Word64 -> Builder
|
|
|
|
module TensorFlow.Types
|
|
|
|
-- | The class of scalar types supported by tensorflow.
|
|
class TensorType a
|
|
tensorType :: TensorType a => a -> DataType
|
|
tensorRefType :: TensorType a => a -> DataType
|
|
tensorVal :: TensorType a => Lens' TensorProto [a]
|
|
|
|
-- | Tensor data with the correct memory layout for tensorflow.
|
|
newtype TensorData a
|
|
TensorData :: TensorData -> TensorData a
|
|
[unTensorData] :: TensorData a -> TensorData
|
|
|
|
-- | Types that can be converted to and from <a>TensorData</a>.
|
|
--
|
|
-- <a>Vector</a> is the most efficient to encode/decode for most element
|
|
-- types.
|
|
class TensorType a => TensorDataType s a
|
|
|
|
-- | Decode the bytes of a <a>TensorData</a> into an <a>s</a>.
|
|
decodeTensorData :: TensorDataType s a => TensorData a -> s a
|
|
|
|
-- | Encode an <a>s</a> into a <a>TensorData</a>.
|
|
--
|
|
-- The values should be in row major order, e.g.,
|
|
--
|
|
-- element 0: index (0, ..., 0); element 1: index (0, ..., 1); ...
|
|
encodeTensorData :: TensorDataType s a => Shape -> s a -> TensorData a
|
|
newtype Scalar a
|
|
Scalar :: a -> Scalar a
|
|
[unScalar] :: Scalar a -> a
|
|
|
|
-- | Shape (dimensions) of a tensor.
|
|
--
|
|
-- TensorFlow supports shapes of unknown rank, which are represented as
|
|
-- <tt>Nothing :: Maybe Shape</tt> in Haskell.
|
|
newtype Shape
|
|
Shape :: [Int64] -> Shape
|
|
protoShape :: Lens' TensorShapeProto Shape
|
|
class Attribute a
|
|
attrLens :: Attribute a => Lens' AttrValue a
|
|
data DataType
|
|
DT_INVALID :: DataType
|
|
DT_FLOAT :: DataType
|
|
DT_DOUBLE :: DataType
|
|
DT_INT32 :: DataType
|
|
DT_UINT8 :: DataType
|
|
DT_INT16 :: DataType
|
|
DT_INT8 :: DataType
|
|
DT_STRING :: DataType
|
|
DT_COMPLEX64 :: DataType
|
|
DT_INT64 :: DataType
|
|
DT_BOOL :: DataType
|
|
DT_QINT8 :: DataType
|
|
DT_QUINT8 :: DataType
|
|
DT_QINT32 :: DataType
|
|
DT_BFLOAT16 :: DataType
|
|
DT_QINT16 :: DataType
|
|
DT_QUINT16 :: DataType
|
|
DT_UINT16 :: DataType
|
|
DT_COMPLEX128 :: DataType
|
|
DT_HALF :: DataType
|
|
DT_RESOURCE :: DataType
|
|
DT_VARIANT :: DataType
|
|
DT_UINT32 :: DataType
|
|
DT_UINT64 :: DataType
|
|
DT_FLOAT_REF :: DataType
|
|
DT_DOUBLE_REF :: DataType
|
|
DT_INT32_REF :: DataType
|
|
DT_UINT8_REF :: DataType
|
|
DT_INT16_REF :: DataType
|
|
DT_INT8_REF :: DataType
|
|
DT_STRING_REF :: DataType
|
|
DT_COMPLEX64_REF :: DataType
|
|
DT_INT64_REF :: DataType
|
|
DT_BOOL_REF :: DataType
|
|
DT_QINT8_REF :: DataType
|
|
DT_QUINT8_REF :: DataType
|
|
DT_QINT32_REF :: DataType
|
|
DT_BFLOAT16_REF :: DataType
|
|
DT_QINT16_REF :: DataType
|
|
DT_QUINT16_REF :: DataType
|
|
DT_UINT16_REF :: DataType
|
|
DT_COMPLEX128_REF :: DataType
|
|
DT_HALF_REF :: DataType
|
|
DT_RESOURCE_REF :: DataType
|
|
DT_VARIANT_REF :: DataType
|
|
DT_UINT32_REF :: DataType
|
|
DT_UINT64_REF :: DataType
|
|
DataType'Unrecognized :: !DataType'UnrecognizedValue -> DataType
|
|
type ResourceHandle = ResourceHandleProto
|
|
|
|
-- | Dynamic type. TensorFlow variants aren't supported yet. This type acts
|
|
-- as a placeholder to simplify op generation.
|
|
data Variant
|
|
|
|
-- | A heterogeneous list type.
|
|
data ListOf f as
|
|
[Nil] :: ListOf f '[]
|
|
[:/] :: f a -> ListOf f as -> ListOf f (a : as)
|
|
infixr 5 :/
|
|
type List = ListOf Identity
|
|
|
|
-- | Equivalent of <a>:/</a> for lists.
|
|
(/:/) :: a -> List as -> List (a : as)
|
|
infixr 5 /:/
|
|
data TensorTypeProxy a
|
|
[TensorTypeProxy] :: TensorType a => TensorTypeProxy a
|
|
class TensorTypes (ts :: [*])
|
|
tensorTypes :: TensorTypes ts => TensorTypeList ts
|
|
type TensorTypeList = ListOf TensorTypeProxy
|
|
fromTensorTypeList :: TensorTypeList ts -> [DataType]
|
|
fromTensorTypes :: forall as. TensorTypes as => Proxy as -> [DataType]
|
|
|
|
-- | A <a>Constraint</a> specifying the possible choices of a
|
|
-- <a>TensorType</a>.
|
|
--
|
|
-- We implement a <a>Constraint</a> like <tt>OneOf '[Double, Float]
|
|
-- a</tt> by turning its natural representation as a disjunction, i.e.,
|
|
--
|
|
-- <pre>
|
|
-- a == Double || a == Float
|
|
-- </pre>
|
|
--
|
|
-- into a conjunction like
|
|
--
|
|
-- <pre>
|
|
-- a /= Int32 && a /= Int64 && a /= ByteString && ...
|
|
-- </pre>
|
|
--
|
|
-- using an enumeration of all the possible <a>TensorType</a>s.
|
|
type OneOf ts a = (TensorType a, TensorTypes' ts, NoneOf (AllTensorTypes \\ ts) a)
|
|
|
|
-- | A constraint checking that two types are different.
|
|
type family a /= b :: Constraint
|
|
type OneOfs ts as = (TensorTypes as, TensorTypes' ts, NoneOfs (AllTensorTypes \\ ts) as)
|
|
|
|
-- | Helper types to produce a reasonable type error message when the
|
|
-- Constraint "a /= a" fails. TODO(judahjacobson): Use ghc-8's
|
|
-- CustomTypeErrors for this.
|
|
data TypeError a
|
|
data ExcludedCase
|
|
|
|
-- | A constraint that the type <tt>a</tt> doesn't appear in the type list
|
|
-- <tt>ts</tt>. Assumes that <tt>a</tt> and each of the elements of
|
|
-- <tt>ts</tt> are <a>TensorType</a>s.
|
|
type family NoneOf ts a :: Constraint
|
|
|
|
-- | Takes the difference of two lists of types.
|
|
type family as \\ bs
|
|
|
|
-- | Removes a type from the given list of types.
|
|
type family Delete a as
|
|
|
|
-- | An enumeration of all valid <a>TensorType</a>s.
|
|
type AllTensorTypes = '[Float, Double, Int8, Int16, Int32, Int64, Word8, Word16, ByteString, Bool]
|
|
instance GHC.Show.Show TensorFlow.Types.Shape
|
|
instance Data.String.IsString a => Data.String.IsString (TensorFlow.Types.Scalar a)
|
|
instance GHC.Real.RealFrac a => GHC.Real.RealFrac (TensorFlow.Types.Scalar a)
|
|
instance GHC.Float.RealFloat a => GHC.Float.RealFloat (TensorFlow.Types.Scalar a)
|
|
instance GHC.Real.Real a => GHC.Real.Real (TensorFlow.Types.Scalar a)
|
|
instance GHC.Float.Floating a => GHC.Float.Floating (TensorFlow.Types.Scalar a)
|
|
instance GHC.Real.Fractional a => GHC.Real.Fractional (TensorFlow.Types.Scalar a)
|
|
instance GHC.Num.Num a => GHC.Num.Num (TensorFlow.Types.Scalar a)
|
|
instance GHC.Classes.Ord a => GHC.Classes.Ord (TensorFlow.Types.Scalar a)
|
|
instance GHC.Classes.Eq a => GHC.Classes.Eq (TensorFlow.Types.Scalar a)
|
|
instance GHC.Show.Show a => GHC.Show.Show (TensorFlow.Types.Scalar a)
|
|
instance TensorFlow.Types.TensorTypes '[]
|
|
instance (TensorFlow.Types.TensorType t, TensorFlow.Types.TensorTypes ts) => TensorFlow.Types.TensorTypes (t : ts)
|
|
instance TensorFlow.Types.All GHC.Classes.Eq (TensorFlow.Types.Map f as) => GHC.Classes.Eq (TensorFlow.Types.ListOf f as)
|
|
instance TensorFlow.Types.All GHC.Show.Show (TensorFlow.Types.Map f as) => GHC.Show.Show (TensorFlow.Types.ListOf f as)
|
|
instance TensorFlow.Types.Attribute GHC.Types.Float
|
|
instance TensorFlow.Types.Attribute Data.ByteString.Internal.ByteString
|
|
instance TensorFlow.Types.Attribute GHC.Int.Int64
|
|
instance TensorFlow.Types.Attribute Proto.Tensorflow.Core.Framework.Types.DataType
|
|
instance TensorFlow.Types.Attribute Proto.Tensorflow.Core.Framework.Tensor.TensorProto
|
|
instance TensorFlow.Types.Attribute GHC.Types.Bool
|
|
instance TensorFlow.Types.Attribute TensorFlow.Types.Shape
|
|
instance TensorFlow.Types.Attribute (GHC.Maybe.Maybe TensorFlow.Types.Shape)
|
|
instance TensorFlow.Types.Attribute Proto.Tensorflow.Core.Framework.AttrValue.AttrValue'ListValue
|
|
instance TensorFlow.Types.Attribute [Proto.Tensorflow.Core.Framework.Types.DataType]
|
|
instance TensorFlow.Types.Attribute [GHC.Int.Int64]
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Types.Float
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Types.Double
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Int.Int8
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Int.Int16
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Int.Int32
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Int.Int64
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Word.Word8
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Word.Word16
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector GHC.Types.Bool
|
|
instance (Foreign.Storable.Storable a, TensorFlow.Types.TensorDataType Data.Vector.Storable.Vector a, TensorFlow.Types.TensorType a) => TensorFlow.Types.TensorDataType Data.Vector.Vector a
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Vector (Data.Complex.Complex GHC.Types.Float)
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Vector (Data.Complex.Complex GHC.Types.Double)
|
|
instance TensorFlow.Types.TensorDataType Data.Vector.Vector Data.ByteString.Internal.ByteString
|
|
instance (TensorFlow.Types.TensorDataType Data.Vector.Vector a, TensorFlow.Types.TensorType a) => TensorFlow.Types.TensorDataType TensorFlow.Types.Scalar a
|
|
instance GHC.Exts.IsList TensorFlow.Types.Shape
|
|
instance TensorFlow.Types.TensorType GHC.Types.Float
|
|
instance TensorFlow.Types.TensorType GHC.Types.Double
|
|
instance TensorFlow.Types.TensorType GHC.Int.Int32
|
|
instance TensorFlow.Types.TensorType GHC.Int.Int64
|
|
instance TensorFlow.Types.TensorType GHC.Word.Word8
|
|
instance TensorFlow.Types.TensorType GHC.Word.Word16
|
|
instance TensorFlow.Types.TensorType GHC.Word.Word32
|
|
instance TensorFlow.Types.TensorType GHC.Word.Word64
|
|
instance TensorFlow.Types.TensorType GHC.Int.Int16
|
|
instance TensorFlow.Types.TensorType GHC.Int.Int8
|
|
instance TensorFlow.Types.TensorType Data.ByteString.Internal.ByteString
|
|
instance TensorFlow.Types.TensorType GHC.Types.Bool
|
|
instance TensorFlow.Types.TensorType (Data.Complex.Complex GHC.Types.Float)
|
|
instance TensorFlow.Types.TensorType (Data.Complex.Complex GHC.Types.Double)
|
|
instance TensorFlow.Types.TensorType TensorFlow.Types.ResourceHandle
|
|
instance TensorFlow.Types.TensorType TensorFlow.Types.Variant
|
|
|
|
module TensorFlow.Output
|
|
|
|
-- | A type of graph node which has no outputs. These nodes are valuable
|
|
-- for causing side effects when they are run.
|
|
newtype ControlNode
|
|
ControlNode :: NodeName -> ControlNode
|
|
[unControlNode] :: ControlNode -> NodeName
|
|
|
|
-- | A device that a node can be assigned to. There's a naming convention
|
|
-- where the device names are constructed from job and replica names.
|
|
newtype Device
|
|
Device :: Text -> Device
|
|
[deviceName] :: Device -> Text
|
|
|
|
-- | The name of a node in the graph. This corresponds to the proto field
|
|
-- NodeDef.name. Includes the scope prefix (if any) and a unique
|
|
-- identifier (if the node was implicitly named).
|
|
newtype NodeName
|
|
NodeName :: Text -> NodeName
|
|
[unNodeName] :: NodeName -> Text
|
|
|
|
-- | Op definition. This corresponds somewhat to the <tt>NodeDef</tt>
|
|
-- proto.
|
|
data OpDef
|
|
OpDef :: !PendingNodeName -> !OpType -> !Map Text AttrValue -> [Output] -> [NodeName] -> OpDef
|
|
[_opName] :: OpDef -> !PendingNodeName
|
|
[_opType] :: OpDef -> !OpType
|
|
[_opAttrs] :: OpDef -> !Map Text AttrValue
|
|
[_opInputs] :: OpDef -> [Output]
|
|
[_opControlInputs] :: OpDef -> [NodeName]
|
|
opName :: Lens' OpDef PendingNodeName
|
|
opType :: Lens' OpDef OpType
|
|
opAttr :: Attribute a => Text -> Lens' OpDef a
|
|
opInputs :: Lens' OpDef [Output]
|
|
opControlInputs :: Lens' OpDef [NodeName]
|
|
|
|
-- | The type of op of a node in the graph. This corresponds to the proto
|
|
-- field NodeDef.op.
|
|
newtype OpType
|
|
OpType :: Text -> OpType
|
|
[unOpType] :: OpType -> Text
|
|
newtype OutputIx
|
|
OutputIx :: Int -> OutputIx
|
|
[unOutputIx] :: OutputIx -> Int
|
|
|
|
-- | An output of a TensorFlow node.
|
|
data Output
|
|
Output :: !OutputIx -> !NodeName -> Output
|
|
[outputIndex] :: Output -> !OutputIx
|
|
[outputNodeName] :: Output -> !NodeName
|
|
output :: OutputIx -> NodeName -> Output
|
|
|
|
-- | The name specified for an unrendered Op. If an Op has an ImplicitName,
|
|
-- it will be assigned based on the opType plus a unique identifier. Does
|
|
-- not contain the "scope" prefix.
|
|
data PendingNodeName
|
|
ExplicitName :: !Text -> PendingNodeName
|
|
ImplicitName :: PendingNodeName
|
|
instance GHC.Classes.Ord TensorFlow.Output.OpDef
|
|
instance GHC.Classes.Eq TensorFlow.Output.OpDef
|
|
instance GHC.Show.Show TensorFlow.Output.Output
|
|
instance GHC.Classes.Ord TensorFlow.Output.Output
|
|
instance GHC.Classes.Eq TensorFlow.Output.Output
|
|
instance GHC.Show.Show TensorFlow.Output.NodeName
|
|
instance GHC.Classes.Ord TensorFlow.Output.NodeName
|
|
instance GHC.Classes.Eq TensorFlow.Output.NodeName
|
|
instance GHC.Show.Show TensorFlow.Output.PendingNodeName
|
|
instance GHC.Classes.Ord TensorFlow.Output.PendingNodeName
|
|
instance GHC.Classes.Eq TensorFlow.Output.PendingNodeName
|
|
instance Data.String.IsString TensorFlow.Output.Device
|
|
instance GHC.Classes.Ord TensorFlow.Output.Device
|
|
instance GHC.Classes.Eq TensorFlow.Output.Device
|
|
instance GHC.Show.Show TensorFlow.Output.OutputIx
|
|
instance GHC.Enum.Enum TensorFlow.Output.OutputIx
|
|
instance GHC.Num.Num TensorFlow.Output.OutputIx
|
|
instance GHC.Classes.Ord TensorFlow.Output.OutputIx
|
|
instance GHC.Classes.Eq TensorFlow.Output.OutputIx
|
|
instance GHC.Show.Show TensorFlow.Output.OpType
|
|
instance GHC.Classes.Ord TensorFlow.Output.OpType
|
|
instance GHC.Classes.Eq TensorFlow.Output.OpType
|
|
instance Data.String.IsString TensorFlow.Output.Output
|
|
instance Data.String.IsString TensorFlow.Output.PendingNodeName
|
|
instance GHC.Show.Show TensorFlow.Output.Device
|
|
instance Data.String.IsString TensorFlow.Output.OpType
|
|
|
|
module TensorFlow.Build
|
|
|
|
-- | A type of graph node which has no outputs. These nodes are valuable
|
|
-- for causing side effects when they are run.
|
|
newtype ControlNode
|
|
ControlNode :: NodeName -> ControlNode
|
|
[unControlNode] :: ControlNode -> NodeName
|
|
data Unique
|
|
explicitName :: Text -> PendingNodeName
|
|
implicitName :: PendingNodeName
|
|
opDef :: OpType -> OpDef
|
|
opDefWithName :: PendingNodeName -> OpType -> OpDef
|
|
opName :: Lens' OpDef PendingNodeName
|
|
opType :: Lens' OpDef OpType
|
|
opAttr :: Attribute a => Text -> Lens' OpDef a
|
|
opInputs :: Lens' OpDef [Output]
|
|
opControlInputs :: Lens' OpDef [NodeName]
|
|
data GraphState
|
|
renderedNodeDefs :: Lens' GraphState (Map NodeName NodeDef)
|
|
|
|
-- | An action for building nodes in a TensorFlow graph. Used to manage
|
|
-- build state internally as part of the <tt>Session</tt> monad.
|
|
data BuildT m a
|
|
|
|
-- | An action for building nodes in a TensorFlow graph.
|
|
type Build = BuildT Identity
|
|
|
|
-- | Lift a <a>Build</a> action into a monad, including any explicit op
|
|
-- renderings.
|
|
class Monad m => MonadBuild m
|
|
build :: MonadBuild m => Build a -> m a
|
|
|
|
-- | Registers the given node to be executed before the next <a>run</a>.
|
|
addInitializer :: MonadBuild m => ControlNode -> m ()
|
|
|
|
-- | This is Control.Monad.Morph.hoist sans the dependency.
|
|
hoistBuildT :: (forall a. m a -> n a) -> BuildT m b -> BuildT n b
|
|
evalBuildT :: Monad m => BuildT m a -> m a
|
|
runBuildT :: BuildT m a -> m (a, GraphState)
|
|
|
|
-- | Produce a GraphDef proto representation of the nodes that are rendered
|
|
-- in the given <a>Build</a> action.
|
|
asGraphDef :: Build a -> GraphDef
|
|
addGraphDef :: MonadBuild m => GraphDef -> m ()
|
|
|
|
-- | Get all the initializers that have accumulated so far, and clear that
|
|
-- buffer.
|
|
flushInitializers :: Monad m => BuildT m [NodeName]
|
|
|
|
-- | Get all the NodeDefs that have accumulated so far, and clear that
|
|
-- buffer.
|
|
flushNodeBuffer :: MonadBuild m => m [NodeDef]
|
|
summaries :: Lens' GraphState [Output]
|
|
|
|
-- | Render the given op if it hasn't been rendered already, and return its
|
|
-- name.
|
|
getOrAddOp :: OpDef -> Build NodeName
|
|
|
|
-- | Add a new node for a given <a>OpDef</a>. This is used for making
|
|
-- "stateful" ops which are not safe to dedup (e.g., "variable" and
|
|
-- "assign").
|
|
addNewOp :: OpDef -> Build NodeName
|
|
|
|
-- | Turn an <a>Output</a> into a string representation for the TensorFlow
|
|
-- foreign APIs.
|
|
encodeOutput :: Output -> Text
|
|
lookupNode :: NodeName -> Build NodeDef
|
|
|
|
-- | Modify some part of the state, run an action, and restore the state
|
|
-- after that action is done.
|
|
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
|
|
|
|
-- | Set a device for all nodes rendered in the given <a>Build</a> action
|
|
-- (unless further overridden by another use of withDevice).
|
|
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
|
|
|
-- | Prepend a scope to all nodes rendered in the given <a>Build</a>
|
|
-- action.
|
|
withNameScope :: MonadBuild m => Text -> m a -> m a
|
|
|
|
-- | Add control inputs to all nodes rendered in the given <a>Build</a>
|
|
-- action.
|
|
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
|
instance Control.Monad.Fail.MonadFail m => Control.Monad.Fail.MonadFail (TensorFlow.Build.BuildT m)
|
|
instance Control.Monad.Fix.MonadFix m => Control.Monad.Fix.MonadFix (TensorFlow.Build.BuildT m)
|
|
instance Control.Monad.Catch.MonadMask m => Control.Monad.Catch.MonadMask (TensorFlow.Build.BuildT m)
|
|
instance Control.Monad.Catch.MonadCatch m => Control.Monad.Catch.MonadCatch (TensorFlow.Build.BuildT m)
|
|
instance Control.Monad.Catch.MonadThrow m => Control.Monad.Catch.MonadThrow (TensorFlow.Build.BuildT m)
|
|
instance GHC.Base.Monad m => Control.Monad.State.Class.MonadState TensorFlow.Build.GraphState (TensorFlow.Build.BuildT m)
|
|
instance Control.Monad.Trans.Class.MonadTrans TensorFlow.Build.BuildT
|
|
instance Control.Monad.IO.Class.MonadIO m => Control.Monad.IO.Class.MonadIO (TensorFlow.Build.BuildT m)
|
|
instance GHC.Base.Monad m => GHC.Base.Monad (TensorFlow.Build.BuildT m)
|
|
instance GHC.Base.Monad m => GHC.Base.Applicative (TensorFlow.Build.BuildT m)
|
|
instance GHC.Base.Functor m => GHC.Base.Functor (TensorFlow.Build.BuildT m)
|
|
instance GHC.Classes.Ord TensorFlow.Build.PendingNode
|
|
instance GHC.Classes.Eq TensorFlow.Build.PendingNode
|
|
instance Data.String.IsString TensorFlow.Build.Scope
|
|
instance GHC.Classes.Ord TensorFlow.Build.Scope
|
|
instance GHC.Classes.Eq TensorFlow.Build.Scope
|
|
instance GHC.Enum.Enum TensorFlow.Build.Unique
|
|
instance GHC.Classes.Ord TensorFlow.Build.Unique
|
|
instance GHC.Classes.Eq TensorFlow.Build.Unique
|
|
instance GHC.Base.Monad m => TensorFlow.Build.MonadBuild (TensorFlow.Build.BuildT m)
|
|
instance GHC.Show.Show TensorFlow.Build.Scope
|
|
|
|
module TensorFlow.Tensor
|
|
|
|
-- | A named output of a TensorFlow operation.
|
|
--
|
|
-- The type parameter <tt>a</tt> is the type of the elements in the
|
|
-- <a>Tensor</a>. The parameter <tt>v</tt> is either:
|
|
--
|
|
-- <ul>
|
|
-- <li><a>Build</a>: An unrendered, immutable value.</li>
|
|
-- <li><a>Value</a>: A rendered, immutable value.</li>
|
|
-- <li><a>Ref</a>: A rendered stateful handle (e.g., a variable).</li>
|
|
-- </ul>
|
|
--
|
|
-- Note that <a>expr</a>, <a>value</a>, <a>render</a> and
|
|
-- <a>renderValue</a> can help convert between the different types of
|
|
-- <a>Tensor</a>.
|
|
data Tensor v a
|
|
[Tensor] :: TensorKind v => {tensorOutput :: v Output} -> Tensor v a
|
|
newtype Value a
|
|
Value :: a -> Value a
|
|
[runValue] :: Value a -> a
|
|
newtype Ref a
|
|
Ref :: a -> Ref a
|
|
[runRef] :: Ref a -> a
|
|
|
|
-- | Cast a 'Tensor Ref' into a 'Tensor Value'. This behaves like a no-op.
|
|
value :: Tensor Ref a -> Tensor Value a
|
|
renderValue :: MonadBuild m => Tensor v a -> m (Tensor Value a)
|
|
|
|
-- | A pair of a <a>Tensor</a> and some data that should be fed into that
|
|
-- <a>Tensor</a> when running the graph.
|
|
data Feed
|
|
Feed :: Output -> TensorData -> Feed
|
|
|
|
-- | A class ensuring that a given tensor is rendered, i.e., has a fixed
|
|
-- name, device, etc.
|
|
class Rendered t
|
|
renderedOutput :: Rendered t => t a -> Output
|
|
tensorNodeName :: Rendered t => t a -> NodeName
|
|
|
|
-- | Create a <a>Feed</a> for feeding the given data into a <a>Tensor</a>
|
|
-- when running the graph.
|
|
--
|
|
-- Note that if a <a>Tensor</a> is rendered, its identity may change; so
|
|
-- feeding the rendered <a>Tensor</a> may be different from feeding the
|
|
-- original <a>Tensor</a>.
|
|
feed :: Rendered t => t a -> TensorData a -> Feed
|
|
|
|
-- | Create a <a>Tensor</a> for a given name. This can be used to reference
|
|
-- nodes in a <tt>GraphDef</tt> that was loaded via <a>addGraphDef</a>.
|
|
-- TODO(judahjacobson): add more safety checks here.
|
|
tensorFromName :: TensorKind v => Text -> Tensor v a
|
|
|
|
-- | Like <a>tensorFromName</a>, but type-restricted to <a>Value</a>.
|
|
tensorValueFromName :: Text -> Tensor Value a
|
|
|
|
-- | Like <a>tensorFromName</a>, but type-restricted to <a>Ref</a>.
|
|
tensorRefFromName :: Text -> Tensor Ref a
|
|
type TensorList v = ListOf (Tensor v)
|
|
tensorListOutputs :: Rendered (Tensor v) => TensorList v as -> [Output]
|
|
|
|
-- | Places all nodes rendered in the given <a>Build</a> action on the same
|
|
-- device as the given Tensor (see also <a>withDevice</a>). Make sure
|
|
-- that the action has side effects of rendering the desired tensors. A
|
|
-- pure return would not have the desired effect.
|
|
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
|
|
|
|
-- | Render a <a>Tensor</a>, fixing its name, scope, device and control
|
|
-- inputs from the <a>MonadBuild</a> context. Also renders any
|
|
-- dependencies of the <a>Tensor</a> that weren't already rendered.
|
|
--
|
|
-- This operation is idempotent; calling <a>render</a> on the same input
|
|
-- in the same context will produce the same result. However, rendering
|
|
-- the same <tt>Tensor Build</tt> in two different contexts may result in
|
|
-- two different <tt>Tensor Value</tt>s.
|
|
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
|
|
expr :: TensorKind v => Tensor v a -> Tensor Build a
|
|
|
|
-- | Records the given summary action in Build for retrieval with Summary
|
|
-- protocol buffer in string form. For safety, use the pre-composed
|
|
-- functions: Logging.scalarSummary and Logging.histogramSummary.
|
|
addSummary :: (MonadBuild m, TensorKind v) => Tensor v ByteString -> m ()
|
|
|
|
-- | Retrieves the summary ops collected thus far. Typically this only
|
|
-- happens once, but if <a>buildWithSummary</a> is used repeatedly, the
|
|
-- values accumulate.
|
|
collectAllSummaries :: MonadBuild m => m [SummaryTensor]
|
|
|
|
-- | Synonym for tensors that return a serialized Summary proto.
|
|
type SummaryTensor = Tensor Value ByteString
|
|
|
|
-- | An internal class for kinds of Tensors.
|
|
class Monad v => TensorKind v
|
|
toBuild :: TensorKind v => v a -> Build a
|
|
|
|
-- | Types which can be converted to <a>Tensor</a>.
|
|
class ToTensor t
|
|
toTensor :: (ToTensor t, TensorType a) => t a -> Tensor Build a
|
|
instance GHC.Base.Functor TensorFlow.Tensor.Ref
|
|
instance GHC.Base.Functor TensorFlow.Tensor.Value
|
|
instance TensorFlow.Tensor.TensorKind v => TensorFlow.Tensor.ToTensor (TensorFlow.Tensor.Tensor v)
|
|
instance TensorFlow.Tensor.Rendered (TensorFlow.Tensor.Tensor TensorFlow.Tensor.Value)
|
|
instance TensorFlow.Tensor.Rendered (TensorFlow.Tensor.Tensor TensorFlow.Tensor.Ref)
|
|
instance TensorFlow.Tensor.TensorKind TensorFlow.Tensor.Value
|
|
instance TensorFlow.Tensor.TensorKind TensorFlow.Tensor.Ref
|
|
instance TensorFlow.Tensor.TensorKind TensorFlow.Build.Build
|
|
instance GHC.Base.Applicative TensorFlow.Tensor.Ref
|
|
instance GHC.Base.Monad TensorFlow.Tensor.Ref
|
|
instance GHC.Base.Applicative TensorFlow.Tensor.Value
|
|
instance GHC.Base.Monad TensorFlow.Tensor.Value
|
|
|
|
module TensorFlow.Nodes
|
|
|
|
-- | Types that contain ops which can be run.
|
|
class Nodes t
|
|
getNodes :: Nodes t => t -> Build (Set NodeName)
|
|
|
|
-- | Types that tensor representations (e.g. <a>Tensor</a>,
|
|
-- <a>ControlNode</a>) can be fetched into.
|
|
--
|
|
-- Includes collections of tensors (e.g. tuples).
|
|
class Nodes t => Fetchable t a
|
|
getFetch :: Fetchable t a => t -> Build (Fetch a)
|
|
|
|
-- | Fetch action. Keeps track of what needs to be fetched and how to
|
|
-- decode the fetched data.
|
|
data Fetch a
|
|
Fetch :: Set Text -> (Map Text TensorData -> a) -> Fetch a
|
|
|
|
-- | Nodes to fetch
|
|
[fetches] :: Fetch a -> Set Text
|
|
|
|
-- | Function to create an <tt>a</tt> from the fetched data.
|
|
[fetchRestore] :: Fetch a -> Map Text TensorData -> a
|
|
nodesUnion :: (Monoid b, Traversable t, Applicative f) => t (f b) -> f b
|
|
fetchTensorVector :: forall a v. TensorType a => Tensor v a -> Build (Fetch (TensorData a))
|
|
instance (TensorFlow.Nodes.Fetchable t1 a1, TensorFlow.Nodes.Fetchable t2 a2) => TensorFlow.Nodes.Fetchable (t1, t2) (a1, a2)
|
|
instance (TensorFlow.Nodes.Fetchable t1 a1, TensorFlow.Nodes.Fetchable t2 a2, TensorFlow.Nodes.Fetchable t3 a3) => TensorFlow.Nodes.Fetchable (t1, t2, t3) (a1, a2, a3)
|
|
instance TensorFlow.Nodes.Fetchable t a => TensorFlow.Nodes.Fetchable [t] [a]
|
|
instance TensorFlow.Nodes.Fetchable t a => TensorFlow.Nodes.Fetchable (GHC.Maybe.Maybe t) (GHC.Maybe.Maybe a)
|
|
instance (a GHC.Types.~ ()) => TensorFlow.Nodes.Fetchable TensorFlow.Output.ControlNode a
|
|
instance (l GHC.Types.~ TensorFlow.Types.List '[]) => TensorFlow.Nodes.Fetchable (TensorFlow.Types.ListOf f '[]) l
|
|
instance (TensorFlow.Nodes.Fetchable (f t) a, TensorFlow.Nodes.Fetchable (TensorFlow.Types.ListOf f ts) (TensorFlow.Types.List as), i GHC.Types.~ Data.Functor.Identity.Identity) => TensorFlow.Nodes.Fetchable (TensorFlow.Types.ListOf f (t : ts)) (TensorFlow.Types.ListOf i (a : as))
|
|
instance (TensorFlow.Types.TensorType a, a GHC.Types.~ a') => TensorFlow.Nodes.Fetchable (TensorFlow.Tensor.Tensor v a) (TensorFlow.Types.TensorData a')
|
|
instance (TensorFlow.Types.TensorType a, TensorFlow.Types.TensorDataType s a, a GHC.Types.~ a') => TensorFlow.Nodes.Fetchable (TensorFlow.Tensor.Tensor v a) (s a')
|
|
instance GHC.Base.Functor TensorFlow.Nodes.Fetch
|
|
instance GHC.Base.Applicative TensorFlow.Nodes.Fetch
|
|
instance (TensorFlow.Nodes.Nodes t1, TensorFlow.Nodes.Nodes t2) => TensorFlow.Nodes.Nodes (t1, t2)
|
|
instance (TensorFlow.Nodes.Nodes t1, TensorFlow.Nodes.Nodes t2, TensorFlow.Nodes.Nodes t3) => TensorFlow.Nodes.Nodes (t1, t2, t3)
|
|
instance TensorFlow.Nodes.Nodes t => TensorFlow.Nodes.Nodes [t]
|
|
instance TensorFlow.Nodes.Nodes t => TensorFlow.Nodes.Nodes (GHC.Maybe.Maybe t)
|
|
instance TensorFlow.Nodes.Nodes TensorFlow.Output.ControlNode
|
|
instance TensorFlow.Nodes.Nodes (TensorFlow.Types.ListOf f '[])
|
|
instance (TensorFlow.Nodes.Nodes (f a), TensorFlow.Nodes.Nodes (TensorFlow.Types.ListOf f as)) => TensorFlow.Nodes.Nodes (TensorFlow.Types.ListOf f (a : as))
|
|
instance TensorFlow.Nodes.Nodes (TensorFlow.Tensor.Tensor v a)
|
|
|
|
module TensorFlow.Session
|
|
type Session = SessionT IO
|
|
data SessionT m a
|
|
|
|
-- | Customization for session. Use the lenses to update:
|
|
-- <a>sessionTarget</a>, <a>sessionTracer</a>, <a>sessionConfig</a>.
|
|
data Options
|
|
|
|
-- | Uses the specified config for the created session.
|
|
sessionConfig :: Lens' Options ConfigProto
|
|
|
|
-- | Target can be: "local", ip:port, host:port. The set of supported
|
|
-- factories depends on the linked-in libraries.
|
|
sessionTarget :: Lens' Options ByteString
|
|
|
|
-- | Uses the given logger to monitor session progress.
|
|
sessionTracer :: Lens' Options Tracer
|
|
|
|
-- | Run <a>Session</a> actions in a new TensorFlow session.
|
|
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
|
|
|
|
-- | Run <a>Session</a> actions in a new TensorFlow session created with
|
|
-- the given <a>Options</a> (see <a>sessionTarget</a>,
|
|
-- <a>sessionConfig</a>).
|
|
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
|
|
|
-- | Lift a <a>Build</a> action into a monad, including any explicit op
|
|
-- renderings.
|
|
class Monad m => MonadBuild m
|
|
build :: MonadBuild m => Build a -> m a
|
|
|
|
-- | Adds all pending rendered nodes to the TensorFlow graph and runs any
|
|
-- pending initializers.
|
|
--
|
|
-- Note that run, runWithFeeds, etc. will all call this function
|
|
-- implicitly.
|
|
extend :: MonadIO m => SessionT m ()
|
|
addGraphDef :: MonadBuild m => GraphDef -> m ()
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering any dependent nodes that aren't
|
|
-- already rendered, and fetch the corresponding values for <tt>a</tt>.
|
|
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering any dependent nodes that aren't
|
|
-- already rendered, feed the given input values, and fetch the
|
|
-- corresponding result values for <tt>a</tt>.
|
|
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering and extending any dependent nodes
|
|
-- that aren't already rendered. This behaves like <a>run</a> except that
|
|
-- it doesn't do any fetches.
|
|
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering any dependent nodes that aren't
|
|
-- already rendered, feed the given input values, and fetch the
|
|
-- corresponding result values for <tt>a</tt>. This behaves like
|
|
-- <a>runWithFeeds</a> except that it doesn't do any fetches.
|
|
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
|
|
|
|
-- | Starts a concurrent thread which evaluates the given Nodes forever
|
|
-- until runSession exits or an exception occurs. Graph extension happens
|
|
-- synchronously, but the resultant run proceeds as a separate thread.
|
|
asyncProdNodes :: (MonadIO m, Nodes t) => t -> SessionT m ()
|
|
instance Control.Monad.Fail.MonadFail m => Control.Monad.Fail.MonadFail (TensorFlow.Session.SessionT m)
|
|
instance Control.Monad.Catch.MonadMask m => Control.Monad.Catch.MonadMask (TensorFlow.Session.SessionT m)
|
|
instance Control.Monad.Catch.MonadCatch m => Control.Monad.Catch.MonadCatch (TensorFlow.Session.SessionT m)
|
|
instance Control.Monad.Catch.MonadThrow m => Control.Monad.Catch.MonadThrow (TensorFlow.Session.SessionT m)
|
|
instance Control.Monad.IO.Class.MonadIO m => Control.Monad.IO.Class.MonadIO (TensorFlow.Session.SessionT m)
|
|
instance GHC.Base.Monad m => GHC.Base.Monad (TensorFlow.Session.SessionT m)
|
|
instance GHC.Base.Monad m => GHC.Base.Applicative (TensorFlow.Session.SessionT m)
|
|
instance GHC.Base.Functor m => GHC.Base.Functor (TensorFlow.Session.SessionT m)
|
|
instance Data.Default.Class.Default TensorFlow.Session.Options
|
|
instance Control.Monad.Trans.Class.MonadTrans TensorFlow.Session.SessionT
|
|
instance GHC.Base.Monad m => TensorFlow.Build.MonadBuild (TensorFlow.Session.SessionT m)
|
|
|
|
module TensorFlow.BuildOp
|
|
|
|
-- | Class of types that can be used as op outputs.
|
|
class BuildResult a
|
|
buildResult :: BuildResult a => Result a
|
|
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
|
|
|
|
-- | Class of types that can be used as outputs of pure ops (see <a>pureOp</a>).
|
|
class PureResult a
|
|
pureResult :: PureResult a => ReaderT (Build OpDef) (State ResultState) a
|
|
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
|
|
|
|
-- | Returns true if all the integers in each tuple are identical. Throws
|
|
-- an error with a descriptive message if not.
|
|
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
|
class BuildInputs a
|
|
buildInputs :: BuildInputs a => a -> Build [Output]
|
|
|
|
-- | Parameters to build an op (for example, the node name or optional
|
|
-- attributes). TODO: be more type safe.
|
|
type OpParams = OpDef -> OpDef
|
|
instance GHC.Show.Show TensorFlow.BuildOp.ResultState
|
|
instance TensorFlow.BuildOp.BuildInputs a => TensorFlow.BuildOp.BuildInputs [a]
|
|
instance TensorFlow.BuildOp.BuildInputs (TensorFlow.Tensor.Tensor v a)
|
|
instance TensorFlow.BuildOp.BuildInputs (TensorFlow.Types.ListOf (TensorFlow.Tensor.Tensor v) as)
|
|
instance TensorFlow.BuildOp.PureResult (TensorFlow.Tensor.Tensor TensorFlow.Build.Build a)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2) => TensorFlow.BuildOp.PureResult (a1, a2)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2, TensorFlow.BuildOp.PureResult a3) => TensorFlow.BuildOp.PureResult (a1, a2, a3)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2, TensorFlow.BuildOp.PureResult a3, TensorFlow.BuildOp.PureResult a4) => TensorFlow.BuildOp.PureResult (a1, a2, a3, a4)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2, TensorFlow.BuildOp.PureResult a3, TensorFlow.BuildOp.PureResult a4, TensorFlow.BuildOp.PureResult a5) => TensorFlow.BuildOp.PureResult (a1, a2, a3, a4, a5)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2, TensorFlow.BuildOp.PureResult a3, TensorFlow.BuildOp.PureResult a4, TensorFlow.BuildOp.PureResult a5, TensorFlow.BuildOp.PureResult a6) => TensorFlow.BuildOp.PureResult (a1, a2, a3, a4, a5, a6)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2, TensorFlow.BuildOp.PureResult a3, TensorFlow.BuildOp.PureResult a4, TensorFlow.BuildOp.PureResult a5, TensorFlow.BuildOp.PureResult a6, TensorFlow.BuildOp.PureResult a7) => TensorFlow.BuildOp.PureResult (a1, a2, a3, a4, a5, a6, a7)
|
|
instance (TensorFlow.BuildOp.PureResult a1, TensorFlow.BuildOp.PureResult a2, TensorFlow.BuildOp.PureResult a3, TensorFlow.BuildOp.PureResult a4, TensorFlow.BuildOp.PureResult a5, TensorFlow.BuildOp.PureResult a6, TensorFlow.BuildOp.PureResult a7, TensorFlow.BuildOp.PureResult a8) => TensorFlow.BuildOp.PureResult (a1, a2, a3, a4, a5, a6, a7, a8)
|
|
instance TensorFlow.BuildOp.PureResult a => TensorFlow.BuildOp.PureResult [a]
|
|
instance TensorFlow.Types.TensorTypes as => TensorFlow.BuildOp.PureResult (TensorFlow.Tensor.TensorList TensorFlow.Build.Build as)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2) => TensorFlow.BuildOp.BuildResult (a1, a2)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2, TensorFlow.BuildOp.BuildResult a3) => TensorFlow.BuildOp.BuildResult (a1, a2, a3)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2, TensorFlow.BuildOp.BuildResult a3, TensorFlow.BuildOp.BuildResult a4) => TensorFlow.BuildOp.BuildResult (a1, a2, a3, a4)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2, TensorFlow.BuildOp.BuildResult a3, TensorFlow.BuildOp.BuildResult a4, TensorFlow.BuildOp.BuildResult a5) => TensorFlow.BuildOp.BuildResult (a1, a2, a3, a4, a5)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2, TensorFlow.BuildOp.BuildResult a3, TensorFlow.BuildOp.BuildResult a4, TensorFlow.BuildOp.BuildResult a5, TensorFlow.BuildOp.BuildResult a6) => TensorFlow.BuildOp.BuildResult (a1, a2, a3, a4, a5, a6)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2, TensorFlow.BuildOp.BuildResult a3, TensorFlow.BuildOp.BuildResult a4, TensorFlow.BuildOp.BuildResult a5, TensorFlow.BuildOp.BuildResult a6, TensorFlow.BuildOp.BuildResult a7) => TensorFlow.BuildOp.BuildResult (a1, a2, a3, a4, a5, a6, a7)
|
|
instance (TensorFlow.BuildOp.BuildResult a1, TensorFlow.BuildOp.BuildResult a2, TensorFlow.BuildOp.BuildResult a3, TensorFlow.BuildOp.BuildResult a4, TensorFlow.BuildOp.BuildResult a5, TensorFlow.BuildOp.BuildResult a6, TensorFlow.BuildOp.BuildResult a7, TensorFlow.BuildOp.BuildResult a8) => TensorFlow.BuildOp.BuildResult (a1, a2, a3, a4, a5, a6, a7, a8)
|
|
instance (TensorFlow.Tensor.TensorKind v, TensorFlow.Tensor.Rendered (TensorFlow.Tensor.Tensor v)) => TensorFlow.BuildOp.BuildResult (TensorFlow.Tensor.Tensor v a)
|
|
instance TensorFlow.BuildOp.BuildResult TensorFlow.Output.ControlNode
|
|
instance (TensorFlow.Tensor.TensorKind v, TensorFlow.Tensor.Rendered (TensorFlow.Tensor.Tensor v), TensorFlow.Types.TensorTypes as) => TensorFlow.BuildOp.BuildResult (TensorFlow.Tensor.TensorList v as)
|
|
instance TensorFlow.BuildOp.BuildResult a => TensorFlow.BuildOp.BuildResult [a]
|
|
|
|
module TensorFlow.ControlFlow
|
|
|
|
-- | Modify a <a>Build</a> action, such that all new ops rendered in it
|
|
-- will depend on the nodes in the first argument.
|
|
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
|
|
|
|
-- | Create an op that groups multiple operations.
|
|
--
|
|
-- When this op finishes, all ops in the given input have finished.
|
|
-- This op has no output.
|
|
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
|
|
|
-- | Does nothing. Only useful as a placeholder for control edges.
|
|
noOp :: MonadBuild m => m ControlNode
|
|
|
|
|
|
-- | The core functionality of TensorFlow.
|
|
--
|
|
-- Unless you are defining ops, you do not need to import other modules
|
|
-- from this package.
|
|
--
|
|
-- Basic ops are provided in the tensorflow-ops and tensorflow-core-ops
|
|
-- packages.
|
|
module TensorFlow.Core
|
|
type Session = SessionT IO
|
|
|
|
-- | Customization for session. Use the lenses to update:
|
|
-- <a>sessionTarget</a>, <a>sessionTracer</a>, <a>sessionConfig</a>.
|
|
data Options
|
|
|
|
-- | Uses the specified config for the created session.
|
|
sessionConfig :: Lens' Options ConfigProto
|
|
|
|
-- | Target can be: "local", ip:port, host:port. The set of supported
|
|
-- factories depends on the linked-in libraries.
|
|
sessionTarget :: Lens' Options ByteString
|
|
|
|
-- | Uses the given logger to monitor session progress.
|
|
sessionTracer :: Lens' Options Tracer
|
|
|
|
-- | Run <a>Session</a> actions in a new TensorFlow session.
|
|
runSession :: (MonadMask m, MonadIO m) => SessionT m a -> m a
|
|
|
|
-- | Run <a>Session</a> actions in a new TensorFlow session created with
|
|
-- the given <a>Options</a> (customized via <a>sessionTarget</a>,
|
|
-- <a>sessionConfig</a> and <a>sessionTracer</a>).
|
|
runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a -> m a
|
|
|
|
-- | Lift a <a>Build</a> action into a monad, including any explicit op
|
|
-- renderings.
|
|
class Monad m => MonadBuild m
|
|
build :: MonadBuild m => Build a -> m a
|
|
|
|
-- | Types that tensor representations (e.g. <a>Tensor</a>,
|
|
-- <a>ControlNode</a>) can be fetched into.
|
|
--
|
|
-- Includes collections of tensors (e.g. tuples).
|
|
class Nodes t => Fetchable t a
|
|
|
|
-- | Types that contain ops which can be run.
|
|
class Nodes t
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering any dependent nodes that aren't
|
|
-- already rendered, and fetch the corresponding values for <tt>a</tt>.
|
|
run :: (MonadIO m, Fetchable t a) => t -> SessionT m a
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering and extending any dependent nodes
|
|
-- that aren't already rendered. This behaves like <a>run</a> except that
|
|
-- it doesn't do any fetches.
|
|
run_ :: (MonadIO m, Nodes t) => t -> SessionT m ()
|
|
|
|
-- | A pair of a <a>Tensor</a> and some data that should be fed into that
|
|
-- <a>Tensor</a> when running the graph.
|
|
data Feed
|
|
|
|
-- | Create a <a>Feed</a> for feeding the given data into a <a>Tensor</a>
|
|
-- when running the graph.
|
|
--
|
|
-- Note that if a <a>Tensor</a> is rendered, its identity may change; so
|
|
-- feeding the rendered <a>Tensor</a> may be different than feeding the
|
|
-- original <a>Tensor</a>.
|
|
feed :: Rendered t => t a -> TensorData a -> Feed
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering any dependent nodes that aren't
|
|
-- already rendered, feed the given input values, and fetch the
|
|
-- corresponding result values for <tt>a</tt>.
|
|
runWithFeeds :: (MonadIO m, Fetchable t a) => [Feed] -> t -> SessionT m a
|
|
|
|
-- | Run a subgraph <tt>t</tt>, rendering any dependent nodes that aren't
|
|
-- already rendered, and feed the given input values. This behaves like
|
|
-- <a>runWithFeeds</a> except that it doesn't do any fetches.
|
|
runWithFeeds_ :: (MonadIO m, Nodes t) => [Feed] -> t -> SessionT m ()
|
|
|
|
-- | Starts a concurrent thread which evaluates the given Nodes forever
|
|
-- until runSession exits or an exception occurs. Graph extension happens
|
|
-- synchronously, but the resultant run proceeds as a separate thread.
|
|
asyncProdNodes :: (MonadIO m, Nodes t) => t -> SessionT m ()
|
|
|
|
-- | An action for building nodes in a TensorFlow graph.
|
|
type Build = BuildT Identity
|
|
|
|
-- | An action for building nodes in a TensorFlow graph. Used to manage
|
|
-- build state internally as part of the <tt>Session</tt> monad.
|
|
data BuildT m a
|
|
|
|
-- | Render a <a>Tensor</a>, fixing its name, scope, device and control
|
|
-- inputs from the <a>MonadBuild</a> context. Also renders any
|
|
-- dependencies of the <a>Tensor</a> that weren't already rendered.
|
|
--
|
|
-- This operation is idempotent; calling <a>render</a> on the same input
|
|
-- in the same context will produce the same result. However, rendering
|
|
-- the same <tt>Tensor Build</tt> in two different contexts may result in
|
|
-- two different <tt>Tensor Value</tt>s.
|
|
render :: MonadBuild m => Tensor Build a -> m (Tensor Value a)
|
|
|
|
-- | Produce a GraphDef proto representation of the nodes that are rendered
|
|
-- in the given <a>Build</a> action.
|
|
asGraphDef :: Build a -> GraphDef
|
|
addGraphDef :: MonadBuild m => GraphDef -> m ()
|
|
opName :: Lens' OpDef PendingNodeName
|
|
opAttr :: Attribute a => Text -> Lens' OpDef a
|
|
|
|
-- | Registers the given node to be executed before the next <a>run</a>.
|
|
addInitializer :: MonadBuild m => ControlNode -> m ()
|
|
|
|
-- | A type of graph node which has no outputs. These nodes are valuable
|
|
-- for causing side effects when they are run.
|
|
data ControlNode
|
|
|
|
-- | A named output of a TensorFlow operation.
|
|
--
|
|
-- The type parameter <tt>a</tt> is the type of the elements in the
|
|
-- <a>Tensor</a>. The parameter <tt>v</tt> is either:
|
|
--
|
|
-- <ul>
|
|
-- <li><a>Build</a>: An unrendered, immutable value.</li>
|
|
-- <li><a>Value</a>: A rendered, immutable value.</li>
|
|
-- <li><a>Ref</a>: A rendered stateful handle (e.g., a variable).</li>
|
|
-- </ul>
|
|
--
|
|
-- Note that <a>expr</a>, <a>value</a>, <a>render</a> and
|
|
-- <a>renderValue</a> can help convert between the different types of
|
|
-- <a>Tensor</a>.
|
|
data Tensor v a
|
|
data Value a
|
|
data Ref a
|
|
|
|
-- | Cast a <tt>Tensor Ref</tt> into a <tt>Tensor Value</tt>. This behaves like a no-op.
|
|
value :: Tensor Ref a -> Tensor Value a
|
|
|
|
-- | Create a <a>Tensor</a> for a given name. This can be used to reference
|
|
-- nodes in a <tt>GraphDef</tt> that was loaded via <a>addGraphDef</a>.
|
|
-- TODO(judahjacobson): add more safety checks here.
|
|
tensorFromName :: TensorKind v => Text -> Tensor v a
|
|
expr :: TensorKind v => Tensor v a -> Tensor Build a
|
|
|
|
-- | The class of scalar types supported by tensorflow.
|
|
class TensorType a
|
|
|
|
-- | Tensor data with the correct memory layout for tensorflow.
|
|
data TensorData a
|
|
|
|
-- | Types that can be converted to and from <a>TensorData</a>.
|
|
--
|
|
-- <a>Vector</a> is the most efficient to encode/decode for most element
|
|
-- types.
|
|
class TensorType a => TensorDataType s a
|
|
|
|
-- | Decode the bytes of a <a>TensorData</a> into a value of type <tt>s a</tt>.
|
|
decodeTensorData :: TensorDataType s a => TensorData a -> s a
|
|
|
|
-- | Encode a value of type <tt>s a</tt> into a <a>TensorData</a>.
|
|
--
|
|
-- The values should be in row major order, e.g.,
|
|
--
|
|
-- element 0: index (0, ..., 0), element 1: index (0, ..., 1), ...
|
|
encodeTensorData :: TensorDataType s a => Shape -> s a -> TensorData a
|
|
type ResourceHandle = ResourceHandleProto
|
|
newtype Scalar a
|
|
Scalar :: a -> Scalar a
|
|
[unScalar] :: Scalar a -> a
|
|
|
|
-- | Shape (dimensions) of a tensor.
|
|
--
|
|
-- TensorFlow supports shapes of unknown rank, which are represented as
|
|
-- <tt>Nothing :: Maybe Shape</tt> in Haskell.
|
|
newtype Shape
|
|
Shape :: [Int64] -> Shape
|
|
|
|
-- | A <a>Constraint</a> specifying the possible choices of a
|
|
-- <a>TensorType</a>.
|
|
--
|
|
-- We implement a <a>Constraint</a> like <tt>OneOf '[Double, Float]
|
|
-- a</tt> by turning its natural representation as a disjunction, i.e.,
|
|
--
|
|
-- <pre>
|
|
-- a == Double || a == Float
|
|
-- </pre>
|
|
--
|
|
-- into a conjunction like
|
|
--
|
|
-- <pre>
|
|
-- a /= Int32 && a /= Int64 && a /= ByteString && ...
|
|
-- </pre>
|
|
--
|
|
-- using an enumeration of all the possible <a>TensorType</a>s.
|
|
type OneOf ts a = (TensorType a, TensorTypes' ts, NoneOf (AllTensorTypes \\ ts) a)
|
|
|
|
-- | A constraint checking that two types are different.
|
|
type family a /= b :: Constraint
|
|
|
|
-- | Places all nodes rendered in the given <a>Build</a> action on the same
|
|
-- device as the given Tensor (see also <a>withDevice</a>). Make sure
|
|
-- that the action has side effects of rendering the desired tensors. A
|
|
-- pure return would not have the desired effect.
|
|
colocateWith :: (MonadBuild m, Rendered t) => t b -> m a -> m a
|
|
|
|
-- | A device that a node can be assigned to. There's a naming convention
|
|
-- where the device names are constructed from job and replica names.
|
|
newtype Device
|
|
Device :: Text -> Device
|
|
[deviceName] :: Device -> Text
|
|
|
|
-- | Set a device for all nodes rendered in the given <a>Build</a> action
|
|
-- (unless further overridden by another use of withDevice).
|
|
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
|
|
|
-- | Prepend a scope to all nodes rendered in the given <a>Build</a>
|
|
-- action.
|
|
withNameScope :: MonadBuild m => Text -> m a -> m a
|
|
|
|
-- | Modify a <a>Build</a> action, such that all new ops rendered in it
|
|
-- will depend on the nodes in the first argument.
|
|
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
|
|
|
|
-- | Create an op that groups multiple operations.
|
|
--
|
|
-- When this op finishes, all ops in the given input have finished.
|
|
-- This op has no output.
|
|
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
|
|
|
-- | Does nothing. Only useful as a placeholder for control edges.
|
|
noOp :: MonadBuild m => m ControlNode
|