mirror of
https://github.com/tensorflow/haskell.git
synced 2025-02-17 05:25:05 +01:00
Add resource-based variable ops. (#98)
The main difference between these and the `Ref`-bases ops is the explicit `readValue` op. I'm not sure how this should interact with gradients and save/restore, so I'm keeping it as a separate module for now. Once we figure out the details, we can merge it into `TensorFlow.Ops` and replace all uses of the old `Ref`-based ops. (That would also fix #92.) Also replaces our special case newtype `ResourceHandle` to `Tensor Value ResourceHandle`, where `ResourceHandle` is the TF proto corresponding to `DT_RESOURCE`.
This commit is contained in:
parent
21b723d542
commit
42f4fc647e
10 changed files with 231 additions and 48 deletions
|
@ -33,15 +33,14 @@ where:
|
||||||
(for example: 'TensorType' or 'OneOf').
|
(for example: 'TensorType' or 'OneOf').
|
||||||
|
|
||||||
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
* @{input tensors}@ is of the form @T_1 -> ... -> T_N@, where each @T@ is of
|
||||||
the form @Tensor Ref a@, @Tensor v a@ or @ResourceHandle@ (or a list of one
|
the form @Tensor Ref a@ or @Tensor v a@ (or a list of one of those types),
|
||||||
of those types), and @a@ is either a concrete type or a (constrained) type
|
and @a@ is either a concrete type or a (constrained) type variable.
|
||||||
variable.
|
|
||||||
|
|
||||||
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
|
* @{output tensors}@ is of the form @(T_1,...,T_N)@ for "pure" ops, and
|
||||||
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
|
@Build (T_1,...,T_N)@ for "stateful" ops. An op is considered "stateful" if
|
||||||
it takes a @Tensor Ref@ or @ResourceHandle@ as input, or if it's explicitly
|
it takes a @Tensor Ref@ or @Tensor v ResourceHandle@ as input, or if it's
|
||||||
marked \"Stateful\" in its @REGISTER_OP@ definition. (If there are no outputs,
|
explicitly marked \"Stateful\" in its @REGISTER_OP@ definition. (If there
|
||||||
it is either @ControlNode@ or @Build ControlNode@.)
|
are no outputs, it is either @ControlNode@ or @Build ControlNode@.)
|
||||||
-}
|
-}
|
||||||
|
|
||||||
module TensorFlow.OpGen
|
module TensorFlow.OpGen
|
||||||
|
@ -155,7 +154,6 @@ imports = stack [
|
||||||
, "import Lens.Family2 ((.~), (&))"
|
, "import Lens.Family2 ((.~), (&))"
|
||||||
, "import TensorFlow.Build"
|
, "import TensorFlow.Build"
|
||||||
, "import TensorFlow.BuildOp"
|
, "import TensorFlow.BuildOp"
|
||||||
, "import TensorFlow.Output (ResourceHandle)"
|
|
||||||
, "import TensorFlow.Tensor"
|
, "import TensorFlow.Tensor"
|
||||||
, "import TensorFlow.Types"
|
, "import TensorFlow.Types"
|
||||||
]
|
]
|
||||||
|
@ -300,7 +298,7 @@ typeSig pre pOp = constraints
|
||||||
| null classConstraints = empty
|
| null classConstraints = empty
|
||||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||||
Just (ArgSomeTensor v) <- [argKind $ parsedArgCase k]]
|
ArgSomeTensor v <- [argKind $ parsedArgCase k]]
|
||||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||||
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||||
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||||
|
@ -333,13 +331,12 @@ typeSig pre pOp = constraints
|
||||||
| otherwise = o
|
| otherwise = o
|
||||||
|
|
||||||
-- | Render an op input or output.
|
-- | Render an op input or output.
|
||||||
-- For example: "Tensor Ref Int64", "Tensor v t", "ResourceHandle"
|
-- For example: "Tensor Ref Int64", "Tensor v t"
|
||||||
tensorArg :: ParsedArg -> Doc
|
tensorArg :: ParsedArg -> Doc
|
||||||
tensorArg p = case parsedArgCase p of
|
tensorArg p = case parsedArgCase p of
|
||||||
ResourceArg -> "ResourceHandle"
|
SimpleArg { argType = t, argKind = k } -> tensorType t k
|
||||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
ListArg { argType = t, argKind = k } -> brackets $ tensorType t k
|
||||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
MixedListArg {argTypeAttr = t, argKind = k}
|
||||||
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
|
||||||
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
|
-> "TensorList" <+> parens (kind k) <+> renderHaskellName t
|
||||||
where
|
where
|
||||||
kind k = case k of
|
kind k = case k of
|
||||||
|
@ -420,8 +417,7 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
||||||
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||||
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||||
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||||
dtTypeToHaskell DT_RESOURCE =
|
dtTypeToHaskell DT_RESOURCE = "ResourceHandle"
|
||||||
error "ResourceHandle must be prevented from getting here."
|
|
||||||
dtTypeToHaskell x =
|
dtTypeToHaskell x =
|
||||||
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@ module TensorFlow.OpGen.ParsedOp
|
||||||
, ParsedArgCase(..)
|
, ParsedArgCase(..)
|
||||||
, ArgType(..)
|
, ArgType(..)
|
||||||
, ArgKind(..)
|
, ArgKind(..)
|
||||||
, argKind
|
|
||||||
, parseOp
|
, parseOp
|
||||||
, camelCase
|
, camelCase
|
||||||
) where
|
) where
|
||||||
|
@ -118,19 +117,18 @@ data ParsedArg = ParsedArg
|
||||||
}
|
}
|
||||||
|
|
||||||
data ParsedArgCase
|
data ParsedArgCase
|
||||||
= SimpleArg { argType :: ArgType, argCaseKind :: ArgKind }
|
= SimpleArg { argType :: ArgType, argKind :: ArgKind }
|
||||||
| ListArg
|
| ListArg
|
||||||
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
{ argLength :: Name -- ^ The attribute that specifies this list's length.
|
||||||
, argType :: ArgType
|
, argType :: ArgType
|
||||||
, argCaseKind :: ArgKind
|
, argKind :: ArgKind
|
||||||
}
|
}
|
||||||
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
| MixedListArg { argTypeAttr :: Name, argKind :: ArgKind }
|
||||||
-- ^ A heterogeneous list.
|
-- ^ A heterogeneous list.
|
||||||
| ResourceArg
|
|
||||||
|
|
||||||
argKind :: ParsedArgCase -> Maybe ArgKind
|
maybeArgType :: ParsedArgCase -> Maybe ArgType
|
||||||
argKind ResourceArg = Nothing
|
maybeArgType MixedListArg{} = Nothing
|
||||||
argKind a = Just $ argCaseKind a
|
maybeArgType a = Just $ argType a
|
||||||
|
|
||||||
-- | The type of an argument.
|
-- | The type of an argument.
|
||||||
data ArgType
|
data ArgType
|
||||||
|
@ -146,10 +144,10 @@ data ArgKind
|
||||||
deriving (Eq)
|
deriving (Eq)
|
||||||
|
|
||||||
isRefCase :: ParsedArgCase -> Bool
|
isRefCase :: ParsedArgCase -> Bool
|
||||||
isRefCase a = case argKind a of
|
isRefCase a
|
||||||
Nothing -> True -- Resource
|
| ArgTensorRef <- argKind a = True
|
||||||
Just ArgTensorRef -> True
|
| Just (ArgTypeFixed DT_RESOURCE) <- maybeArgType a = True
|
||||||
_ -> False
|
| otherwise = False
|
||||||
|
|
||||||
makeName :: Text -> Name
|
makeName :: Text -> Name
|
||||||
makeName n = Name
|
makeName n = Name
|
||||||
|
@ -314,7 +312,6 @@ parseArg a tKind = ParsedArg
|
||||||
|
|
||||||
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
|
parseArgCase :: OpDef'ArgDef -> ArgKind -> ParsedArgCase
|
||||||
parseArgCase a tKind
|
parseArgCase a tKind
|
||||||
| a ^. type' == DT_RESOURCE = ResourceArg
|
|
||||||
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
|
| Just n <- maybeAttr (a ^. typeListAttr) = MixedListArg n tKind
|
||||||
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
|
| Just n <- maybeAttr (a ^. numberAttr) = ListArg n thisArgType tKind
|
||||||
| otherwise = SimpleArg thisArgType tKind
|
| otherwise = SimpleArg thisArgType tKind
|
||||||
|
|
109
tensorflow-ops/src/TensorFlow/Variable.hs
Normal file
109
tensorflow-ops/src/TensorFlow/Variable.hs
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
-- | An implementation of ResourceHandle-based variables.
|
||||||
|
--
|
||||||
|
-- The main difference between this and 'Ref'-based variables is
|
||||||
|
-- that reads are explicit, via the 'readValue' op.
|
||||||
|
--
|
||||||
|
-- TODO: given that distinction, figure out a good story around
|
||||||
|
-- gradients and save/restore. Then, merge this module into
|
||||||
|
-- TensorFlow.Ops.
|
||||||
|
{-# LANGUAGE RecursiveDo #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
module TensorFlow.Variable
|
||||||
|
( Variable
|
||||||
|
, variable
|
||||||
|
, variable'
|
||||||
|
, readValue
|
||||||
|
, initializedVariable
|
||||||
|
, initializedVariable'
|
||||||
|
, zeroInitializedVariable
|
||||||
|
, zeroInitializedVariable'
|
||||||
|
, assign
|
||||||
|
, assign'
|
||||||
|
, assignAdd
|
||||||
|
, assignAdd'
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.Text.Encoding (encodeUtf8)
|
||||||
|
import Lens.Family2 ((.~), (&))
|
||||||
|
import TensorFlow.Core
|
||||||
|
import TensorFlow.Build (opDef)
|
||||||
|
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||||
|
import TensorFlow.Output (opInputs, unNodeName)
|
||||||
|
import TensorFlow.Tensor (tensorNodeName)
|
||||||
|
import TensorFlow.Types (tensorType)
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
import TensorFlow.Ops (zeros)
|
||||||
|
|
||||||
|
newtype Variable a = Variable (Tensor Value ResourceHandle)
|
||||||
|
|
||||||
|
-- | Creates a new, uninitialized variable.
|
||||||
|
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
||||||
|
variable = variable' id
|
||||||
|
|
||||||
|
variable' :: forall m a . (MonadBuild m, TensorType a)
|
||||||
|
=> OpParams -> Shape -> m (Variable a)
|
||||||
|
variable' params s = build $ do
|
||||||
|
-- Each variable needs a unique "shared_name". Use MonadFix to
|
||||||
|
-- set the attribute to the same name as the variable itself, without
|
||||||
|
-- exposing more internals of the Build module.
|
||||||
|
rec t <- CoreOps.varHandleOp' (params . (opAttr "shared_name" .~ n))
|
||||||
|
(tensorType (undefined :: a)) s
|
||||||
|
let n = encodeUtf8 $ unNodeName $ tensorNodeName t
|
||||||
|
return $ Variable t
|
||||||
|
|
||||||
|
-- | Creates a variable initialized to the given value.
|
||||||
|
-- Initialization happens next time session runs.
|
||||||
|
initializedVariable :: (MonadBuild m, TensorType a)
|
||||||
|
=> Tensor v a -> m (Variable a)
|
||||||
|
initializedVariable = initializedVariable' id
|
||||||
|
|
||||||
|
initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
|
||||||
|
=> OpParams -> Tensor v a -> m (Variable a)
|
||||||
|
initializedVariable' params initializer = do
|
||||||
|
-- The shape is not known initially.
|
||||||
|
v@(Variable h) <- variable' params (Shape [])
|
||||||
|
i <- CoreOps.assignVariableOp h initializer
|
||||||
|
addInitializer =<< group i
|
||||||
|
return v
|
||||||
|
|
||||||
|
-- | Creates a zero-initialized variable with the given shape.
|
||||||
|
zeroInitializedVariable
|
||||||
|
:: (MonadBuild m, TensorType a, Num a) => Shape -> m (Variable a)
|
||||||
|
zeroInitializedVariable = zeroInitializedVariable' id
|
||||||
|
|
||||||
|
zeroInitializedVariable'
|
||||||
|
:: (MonadBuild m, TensorType a, Num a) => OpParams -> Shape -> m (Variable a)
|
||||||
|
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||||
|
|
||||||
|
-- | Gets the value stored in a variable.
|
||||||
|
readValue :: TensorType a => Variable a -> Tensor Build a
|
||||||
|
readValue = readValue' id
|
||||||
|
|
||||||
|
readValue' :: forall a . TensorType a
|
||||||
|
=> OpParams -> Variable a -> Tensor Build a
|
||||||
|
readValue' params (Variable h)
|
||||||
|
= pureOp [] $ do
|
||||||
|
os <- buildInputs h
|
||||||
|
pure $ opDef "ReadVariableOp"
|
||||||
|
& (params
|
||||||
|
. (opAttr "dtype" .~ tensorType (undefined :: a))
|
||||||
|
. (opInputs .~ os))
|
||||||
|
|
||||||
|
-- | Sets the value of a variable.
|
||||||
|
assign :: (MonadBuild m, TensorType a)
|
||||||
|
=> Variable a -> Tensor v a -> m ControlNode
|
||||||
|
assign = assign' id
|
||||||
|
|
||||||
|
assign' :: (MonadBuild m, TensorType a)
|
||||||
|
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||||
|
assign' params (Variable h) v = CoreOps.assignVariableOp' params h v
|
||||||
|
|
||||||
|
-- | Increments the value of a variable.
|
||||||
|
assignAdd :: (MonadBuild m, TensorType a)
|
||||||
|
=> Variable a -> Tensor v a -> m ControlNode
|
||||||
|
assignAdd = assignAdd' id
|
||||||
|
|
||||||
|
assignAdd' :: (MonadBuild m, TensorType a)
|
||||||
|
=> OpParams -> Variable a -> Tensor v a -> m ControlNode
|
||||||
|
assignAdd' params (Variable h) v = CoreOps.assignAddVariableOp' params h v
|
|
@ -16,6 +16,7 @@ library
|
||||||
exposed-modules: TensorFlow.Gradient
|
exposed-modules: TensorFlow.Gradient
|
||||||
, TensorFlow.Ops
|
, TensorFlow.Ops
|
||||||
, TensorFlow.EmbeddingOps
|
, TensorFlow.EmbeddingOps
|
||||||
|
, TensorFlow.Variable
|
||||||
build-depends: proto-lens == 0.2.*
|
build-depends: proto-lens == 0.2.*
|
||||||
, base >= 4.7 && < 5
|
, base >= 4.7 && < 5
|
||||||
, bytestring
|
, bytestring
|
||||||
|
@ -37,6 +38,8 @@ Test-Suite RegressionTest
|
||||||
hs-source-dirs: tests
|
hs-source-dirs: tests
|
||||||
build-depends: base
|
build-depends: base
|
||||||
, HUnit
|
, HUnit
|
||||||
|
, lens-family
|
||||||
|
, transformers
|
||||||
, random
|
, random
|
||||||
, tensorflow
|
, tensorflow
|
||||||
, tensorflow-core-ops
|
, tensorflow-core-ops
|
||||||
|
@ -126,6 +129,23 @@ Test-Suite OpsTest
|
||||||
, transformers
|
, transformers
|
||||||
, vector
|
, vector
|
||||||
|
|
||||||
|
Test-Suite VariableTest
|
||||||
|
default-language: Haskell2010
|
||||||
|
type: exitcode-stdio-1.0
|
||||||
|
main-is: VariableTest.hs
|
||||||
|
hs-source-dirs: tests
|
||||||
|
build-depends: HUnit
|
||||||
|
, base
|
||||||
|
, google-shim
|
||||||
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
|
, tensorflow-ops
|
||||||
|
, test-framework
|
||||||
|
, test-framework-hunit
|
||||||
|
, transformers
|
||||||
|
, vector
|
||||||
|
|
||||||
|
|
||||||
Test-Suite DataFlowOpsTest
|
Test-Suite DataFlowOpsTest
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
type: exitcode-stdio-1.0
|
type: exitcode-stdio-1.0
|
||||||
|
|
62
tensorflow-ops/tests/VariableTest.hs
Normal file
62
tensorflow-ops/tests/VariableTest.hs
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
module Main (main) where
|
||||||
|
|
||||||
|
import Control.Monad.IO.Class (liftIO)
|
||||||
|
import qualified Data.Vector.Storable as V
|
||||||
|
import Google.Test (googleTest)
|
||||||
|
import TensorFlow.Core (unScalar, run_, runSession, run)
|
||||||
|
import qualified TensorFlow.Ops as Ops
|
||||||
|
import TensorFlow.Variable
|
||||||
|
( readValue
|
||||||
|
, initializedVariable
|
||||||
|
, assign
|
||||||
|
, assignAdd
|
||||||
|
)
|
||||||
|
import Test.Framework (Test)
|
||||||
|
import Test.Framework.Providers.HUnit (testCase)
|
||||||
|
import Test.HUnit ((@=?))
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = googleTest [ testInitializedVariable
|
||||||
|
, testInitializedVariableShape
|
||||||
|
, testRereadRef
|
||||||
|
, testAssignAdd
|
||||||
|
]
|
||||||
|
|
||||||
|
testInitializedVariable :: Test
|
||||||
|
testInitializedVariable =
|
||||||
|
testCase "testInitializedVariable" $ runSession $ do
|
||||||
|
(formula, reset) <- do
|
||||||
|
v <- initializedVariable 42
|
||||||
|
r <- assign v 24
|
||||||
|
return (1 + readValue v, r)
|
||||||
|
result <- run formula
|
||||||
|
liftIO $ 43 @=? (unScalar result :: Float)
|
||||||
|
run_ reset -- Updates v to a different value
|
||||||
|
rerunResult <- run formula
|
||||||
|
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||||
|
|
||||||
|
testInitializedVariableShape :: Test
|
||||||
|
testInitializedVariableShape =
|
||||||
|
testCase "testInitializedVariableShape" $ runSession $ do
|
||||||
|
vector <- initializedVariable (Ops.constant [1] [42 :: Float])
|
||||||
|
result <- run (readValue vector)
|
||||||
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
|
-- | See https://github.com/tensorflow/haskell/issues/92.
|
||||||
|
-- Even though we're not explicitly evaluating `f0` until the end,
|
||||||
|
-- it should hold the earlier value of the variable.
|
||||||
|
testRereadRef :: Test
|
||||||
|
testRereadRef = testCase "testReRunAssign" $ runSession $ do
|
||||||
|
w <- initializedVariable 0
|
||||||
|
f0 <- run (readValue w)
|
||||||
|
run_ =<< assign w (Ops.scalar (0.1 :: Float))
|
||||||
|
f1 <- run (readValue w)
|
||||||
|
liftIO $ (0.0, 0.1) @=? (unScalar f0, unScalar f1)
|
||||||
|
|
||||||
|
testAssignAdd :: Test
|
||||||
|
testAssignAdd = testCase "testAssignAdd" $ runSession $ do
|
||||||
|
w <- initializedVariable 42
|
||||||
|
run_ =<< assignAdd w 17
|
||||||
|
f1 <- run (readValue w)
|
||||||
|
liftIO $ (42 + 17 :: Float) @=? unScalar f1
|
|
@ -62,6 +62,7 @@ module TensorFlow.Build
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
import Control.Monad.Catch (MonadThrow, MonadCatch, MonadMask)
|
||||||
|
import Control.Monad.Fix (MonadFix(..))
|
||||||
import Control.Monad.IO.Class (MonadIO(..))
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||||
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
import Control.Monad.Trans.State.Strict(StateT(..), mapStateT, evalStateT)
|
||||||
|
@ -192,7 +193,8 @@ summaries = lens _summaries (\g x -> g { _summaries = x })
|
||||||
-- Used to manage build state internally as part of the @Session@ monad.
|
-- Used to manage build state internally as part of the @Session@ monad.
|
||||||
newtype BuildT m a = BuildT (StateT GraphState m a)
|
newtype BuildT m a = BuildT (StateT GraphState m a)
|
||||||
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans,
|
||||||
MonadState GraphState, MonadThrow, MonadCatch, MonadMask)
|
MonadState GraphState, MonadThrow, MonadCatch, MonadMask,
|
||||||
|
MonadFix)
|
||||||
|
|
||||||
-- | An action for building nodes in a TensorFlow graph.
|
-- | An action for building nodes in a TensorFlow graph.
|
||||||
type Build = BuildT Identity
|
type Build = BuildT Identity
|
||||||
|
@ -307,7 +309,6 @@ renderPendingNode (PendingNode scope pendingName nodeDef)
|
||||||
nextUnique .= succ u
|
nextUnique .= succ u
|
||||||
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
return $ nodeDef ^. op <> "_" <> Text.pack (show k)
|
||||||
|
|
||||||
|
|
||||||
-- | Turn an 'Output' into a string representation for the TensorFlow
|
-- | Turn an 'Output' into a string representation for the TensorFlow
|
||||||
-- foreign APIs.
|
-- foreign APIs.
|
||||||
encodeOutput :: Output -> Text
|
encodeOutput :: Output -> Text
|
||||||
|
|
|
@ -126,9 +126,6 @@ recordResult = do
|
||||||
put $! ResultState (i+1) ns
|
put $! ResultState (i+1) ns
|
||||||
return $! output i o
|
return $! output i o
|
||||||
|
|
||||||
instance BuildResult ResourceHandle where
|
|
||||||
buildResult = ResourceHandle <$> recordResult
|
|
||||||
|
|
||||||
instance Rendered v => BuildResult (Tensor v a) where
|
instance Rendered v => BuildResult (Tensor v a) where
|
||||||
buildResult = Tensor . pure <$> recordResult
|
buildResult = Tensor . pure <$> recordResult
|
||||||
|
|
||||||
|
@ -302,9 +299,6 @@ instance BuildInputs (ListOf (Tensor v) as) where
|
||||||
buildInputs Nil = return []
|
buildInputs Nil = return []
|
||||||
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
|
buildInputs (t :/ ts) = liftM2 (++) (buildInputs t) (buildInputs ts)
|
||||||
|
|
||||||
instance BuildInputs ResourceHandle where
|
|
||||||
buildInputs (ResourceHandle o) = return [o]
|
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
-- | Parameters to build an op (for example, the node name or optional attributes).
|
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||||
|
|
|
@ -52,6 +52,7 @@ module TensorFlow.Core
|
||||||
, addGraphDef
|
, addGraphDef
|
||||||
, opName
|
, opName
|
||||||
, opAttr
|
, opAttr
|
||||||
|
, addInitializer
|
||||||
-- * Tensor
|
-- * Tensor
|
||||||
, ControlNode
|
, ControlNode
|
||||||
, Tensor
|
, Tensor
|
||||||
|
@ -64,6 +65,7 @@ module TensorFlow.Core
|
||||||
, TensorType
|
, TensorType
|
||||||
, TensorData
|
, TensorData
|
||||||
, TensorDataType(decodeTensorData, encodeTensorData)
|
, TensorDataType(decodeTensorData, encodeTensorData)
|
||||||
|
, ResourceHandle
|
||||||
, Scalar(..)
|
, Scalar(..)
|
||||||
, Shape(..)
|
, Shape(..)
|
||||||
, OneOf
|
, OneOf
|
||||||
|
|
|
@ -33,7 +33,6 @@ module TensorFlow.Output
|
||||||
, Output(..)
|
, Output(..)
|
||||||
, output
|
, output
|
||||||
, PendingNodeName(..)
|
, PendingNodeName(..)
|
||||||
, ResourceHandle(..)
|
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
|
@ -127,8 +126,3 @@ instance IsString Output where
|
||||||
-> Output (fromInteger ix) $ assigned n
|
-> Output (fromInteger ix) $ assigned n
|
||||||
_ -> Output 0 $ assigned s
|
_ -> Output 0 $ assigned s
|
||||||
where assigned = NodeName . Text.pack
|
where assigned = NodeName . Text.pack
|
||||||
|
|
||||||
|
|
||||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
|
||||||
-- resources are variables.
|
|
||||||
newtype ResourceHandle = ResourceHandle Output
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ module TensorFlow.Types
|
||||||
, protoShape
|
, protoShape
|
||||||
, Attribute(..)
|
, Attribute(..)
|
||||||
, DataType(..)
|
, DataType(..)
|
||||||
|
, ResourceHandle
|
||||||
-- * Lists
|
-- * Lists
|
||||||
, ListOf(..)
|
, ListOf(..)
|
||||||
, List
|
, List
|
||||||
|
@ -94,15 +95,18 @@ import Proto.Tensorflow.Core.Framework.AttrValue
|
||||||
, shape
|
, shape
|
||||||
, tensor
|
, tensor
|
||||||
)
|
)
|
||||||
|
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
||||||
|
(ResourceHandle)
|
||||||
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
||||||
( TensorProto(..)
|
( TensorProto(..)
|
||||||
, floatVal
|
|
||||||
, doubleVal
|
|
||||||
, intVal
|
|
||||||
, stringVal
|
|
||||||
, int64Val
|
|
||||||
, stringVal
|
|
||||||
, boolVal
|
, boolVal
|
||||||
|
, doubleVal
|
||||||
|
, floatVal
|
||||||
|
, intVal
|
||||||
|
, int64Val
|
||||||
|
, resourceHandleVal
|
||||||
|
, stringVal
|
||||||
|
, stringVal
|
||||||
)
|
)
|
||||||
import Proto.Tensorflow.Core.Framework.TensorShape
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
( TensorShapeProto(..)
|
( TensorShapeProto(..)
|
||||||
|
@ -183,6 +187,10 @@ instance TensorType (Complex Double) where
|
||||||
tensorRefType _ = DT_COMPLEX128
|
tensorRefType _ = DT_COMPLEX128
|
||||||
tensorVal = error "TODO (Complex Double)"
|
tensorVal = error "TODO (Complex Double)"
|
||||||
|
|
||||||
|
instance TensorType ResourceHandle where
|
||||||
|
tensorType _ = DT_RESOURCE
|
||||||
|
tensorRefType _ = DT_RESOURCE_REF
|
||||||
|
tensorVal = resourceHandleVal
|
||||||
|
|
||||||
-- | Tensor data with the correct memory layout for tensorflow.
|
-- | Tensor data with the correct memory layout for tensorflow.
|
||||||
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
|
||||||
|
|
Loading…
Add table
Reference in a new issue