mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
Merge branch 'master' of https://github.com/tensorflow/haskell into webarchive
This commit is contained in:
commit
ce4902e8ac
|
@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
|
|||
let x = TF.vector xData
|
||||
y = TF.vector yData
|
||||
-- Create scalar variables for slope and intercept.
|
||||
w <- TF.build (TF.initializedVariable 0)
|
||||
b <- TF.build (TF.initializedVariable 0)
|
||||
w <- TF.initializedVariable 0
|
||||
b <- TF.initializedVariable 0
|
||||
-- Define the loss function.
|
||||
let yHat = (x `TF.mul` w) `TF.add` b
|
||||
loss = TF.square (yHat `TF.sub` y)
|
||||
-- Optimize with gradient descent.
|
||||
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
|
||||
trainStep <- gradientDescent 0.001 loss [w, b]
|
||||
replicateM_ 1000 (TF.run trainStep)
|
||||
-- Return the learned parameters.
|
||||
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
||||
|
@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
|
|||
gradientDescent :: Float
|
||||
-> TF.Tensor TF.Value Float
|
||||
-> [TF.Tensor TF.Ref Float]
|
||||
-> TF.Build TF.ControlNode
|
||||
-> TF.Session TF.ControlNode
|
||||
gradientDescent alpha loss params = do
|
||||
let applyGrad param grad =
|
||||
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||
|
|
|
@ -64,43 +64,9 @@ generatingOpsWrappers = hooks
|
|||
(prettyLazyText 80 $ docOpList flags x)
|
||||
|
||||
blackList =
|
||||
-- A few data flow ops take a list of heterogeneous
|
||||
-- parameters which we don't support in general form.
|
||||
[ "HashTable"
|
||||
, "MutableDenseHashTable"
|
||||
, "MutableHashTable"
|
||||
, "MutableHashTableOfTensors"
|
||||
, "QueueDequeue"
|
||||
, "QueueDequeueMany"
|
||||
, "QueueDequeueUpTo"
|
||||
, "Stack"
|
||||
, "TensorArray"
|
||||
, "TensorArrayV2"
|
||||
, "QueueEnqueueManyV2"
|
||||
, "QueueDequeueV2"
|
||||
, "QueueDequeueUpToV2"
|
||||
, "QueueEnqueueV2"
|
||||
, "QueueDequeueManyV2"
|
||||
, "Stage"
|
||||
, "Unstage"
|
||||
-- These should be possible to support by adding a bunch of
|
||||
-- overloads with a variable number of tuple arguments.
|
||||
, "Assert"
|
||||
, "BarrierTakeMany"
|
||||
, "Print"
|
||||
, "QueueEnqueue"
|
||||
, "QueueEnqueueMany"
|
||||
-- Need list of types support.
|
||||
, "DecodeCSV"
|
||||
, "ParseExample"
|
||||
, "ParseSingleSequenceExample"
|
||||
, "RestoreV2"
|
||||
, "Save"
|
||||
, "SaveV2"
|
||||
, "SaveSlices"
|
||||
, "SymbolicGradient"
|
||||
, "_ArrayToList"
|
||||
, "_ListToArray"
|
||||
[ -- Requires the "func" type:
|
||||
"SymbolicGradient"
|
||||
-- Easy: support larger result tuples.
|
||||
, "ParseSingleSequenceExample"
|
||||
, "Skipgram"
|
||||
]
|
||||
|
|
|
@ -49,7 +49,7 @@ import TensorFlow.Tensor
|
|||
)
|
||||
import TensorFlow.Ops
|
||||
import TensorFlow.Session
|
||||
(runSession, run, run_, runWithFeeds, build, buildAnd)
|
||||
(runSession, run, run_, runWithFeeds, build)
|
||||
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -108,7 +108,7 @@ testGraphDefExec :: Test
|
|||
testGraphDefExec = testCase "testGraphDefExec" $ do
|
||||
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
|
||||
runSession $ do
|
||||
build $ addGraphDef graphDef
|
||||
addGraphDef graphDef
|
||||
x <- run $ tensorFromName ValueKind "Mul_2"
|
||||
liftIO $ (50 :: Float) @=? unScalar x
|
||||
|
||||
|
@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
|
|||
wtsCkptPath <- liftIO wtsCkpt
|
||||
biasCkptPath <- liftIO biasCkpt
|
||||
-- Run those restoring nodes on the graph in the current session.
|
||||
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
|
||||
run_ =<< (sequence :: Monad m => [m a] -> m [a])
|
||||
[ restore wtsCkptPath wts
|
||||
, restoreFromName biasCkptPath "bias" bias
|
||||
]
|
||||
|
|
|
@ -23,7 +23,7 @@ module TensorFlow.NN
|
|||
import Prelude hiding ( log
|
||||
, exp
|
||||
)
|
||||
import TensorFlow.Build ( Build
|
||||
import TensorFlow.Build ( MonadBuild
|
||||
, render
|
||||
, withNameScope
|
||||
)
|
||||
|
@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
|
|||
--
|
||||
-- `logits` and `targets` must have the same type and shape.
|
||||
sigmoidCrossEntropyWithLogits
|
||||
:: (OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
:: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
|
||||
=> Tensor Value a -- ^ __logits__
|
||||
-> Tensor Value a -- ^ __targets__
|
||||
-> Build (Tensor Value a)
|
||||
-> m (Tensor Value a)
|
||||
sigmoidCrossEntropyWithLogits logits targets = do
|
||||
logits' <- render logits
|
||||
targets' <- render targets
|
||||
|
|
|
@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
|
|||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.NN as TF
|
||||
|
@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
|
|||
|
||||
assertAllClose (head r) (V.fromList [0.5, -0.5])
|
||||
|
||||
run :: TF.Fetchable t a => TF.Build t -> IO a
|
||||
run = TF.runSession . TF.buildAnd TF.run
|
||||
run :: TF.Fetchable t a => TF.Session t -> IO a
|
||||
run = TF.runSession . (>>= TF.run)
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testGradientAtZero
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
@ -147,6 +148,7 @@ imports = stack [
|
|||
"import Data.ByteString (ByteString)"
|
||||
, "import Data.Complex (Complex)"
|
||||
, "import Data.Int (Int8, Int16, Int32, Int64)"
|
||||
, "import Data.Proxy (Proxy(Proxy))"
|
||||
, "import Data.Word (Word8, Word16)"
|
||||
, "import Lens.Family2 ((.~), (&))"
|
||||
, "import TensorFlow.Build"
|
||||
|
@ -171,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName
|
|||
renderOp :: ParsedOp -> Doc
|
||||
renderOp pOp = stack $
|
||||
[ haddocks
|
||||
, n <+> "::" <+> hang 0 (typeSig pOp)
|
||||
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||
-- Prevent unreasonably long compilation times on ghc-7.10, due
|
||||
-- to stack calling "-dump-hi" which (unnecessarily) includes the
|
||||
-- inlining information, and is large for ops with many arguments.
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
, "{-# NOINLINE " <> n <> "#-}"
|
||||
#endif
|
||||
, n <+> "::" <+> hang 0 (typeSig empty pOp)
|
||||
, n <+> "=" <+> n <> "' id"
|
||||
, n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
|
||||
, n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
|
||||
<+> "=" </> -- args are indented
|
||||
-- the body needs to be indented wrt the name
|
||||
indent indentation (functionBody pOp)
|
||||
] ++ whereClause listSizeAttrs
|
||||
where
|
||||
n = renderHaskellName $ parsedOpName pOp
|
||||
n' = n <> "'"
|
||||
listSizeAttrs = inferredListSizeAttrs pOp
|
||||
args = sep $ map renderHaskellName
|
||||
$ map attrName (explicitInputAttrs pOp)
|
||||
++ map parsedArgName (parsedInputs pOp)
|
||||
args = sep $ "op'options"
|
||||
: (map renderHaskellName
|
||||
$ map attrName (explicitInputAttrs pOp)
|
||||
++ map parsedArgName (parsedInputs pOp))
|
||||
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
|
||||
|
||||
-- | A check that all lists of the given size have the given length.
|
||||
|
@ -210,15 +222,21 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
|
|||
whereClause [] = []
|
||||
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
||||
where
|
||||
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
|
||||
defineLengthAttr a = renderHaskellAttrName a <+> "="
|
||||
<+> "fromIntegral (length"
|
||||
<+> renderHaskellName (NE.head $ attrInfo a)
|
||||
<> ") :: Int64"
|
||||
|
||||
renderHaskellAttrName :: Attr a -> Doc
|
||||
renderHaskellAttrName = renderHaskellName . attrName
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
where
|
||||
maybeLift
|
||||
| parsedOpIsMonadic pOp = "build $"
|
||||
| otherwise = ""
|
||||
buildFunction
|
||||
| null outputListsSizes = "buildOp"
|
||||
| otherwise = "buildListOp" <+>
|
||||
|
@ -229,9 +247,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
|||
<- parsedOutputs pOp]
|
||||
buildOpParts =
|
||||
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
|
||||
-- Renders tensor arguments.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+>
|
||||
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
|
||||
-- Renders type parameter arguments.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
|
||||
| a <- inferredTypeAttrs pOp, let n = attrName a
|
||||
] ++
|
||||
-- Renders mandatory attributes as function parameters.
|
||||
|
@ -241,9 +258,17 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
|||
-- Renders sizes of tensor list types having number_attr.
|
||||
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
|
||||
| a <- inferredListSizeAttrs pOp, let n = attrName a
|
||||
]
|
||||
] ++
|
||||
["& op'options"]
|
||||
|
||||
|
||||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
inferredTypeExpr a
|
||||
| typeParamIsList $ attrInfo a
|
||||
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
||||
<> ")"
|
||||
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
|
||||
<> ")"
|
||||
|
||||
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||
-- debugging.
|
||||
|
@ -258,23 +283,28 @@ extras d = enclose "{-\n" "\n-}" $
|
|||
-- | The type signature for an op.
|
||||
-- Of the form:
|
||||
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
|
||||
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||
-- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2
|
||||
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
|
||||
-- "Tensor t2 v2" is an output.
|
||||
typeSig :: ParsedOp -> Doc
|
||||
typeSig pOp = constraints
|
||||
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||
typeSig :: Doc -> ParsedOp -> Doc
|
||||
typeSig pre pOp = constraints
|
||||
<+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
|
||||
++ map tensorArgAndComment (parsedInputs pOp)
|
||||
++ [outputs])
|
||||
where
|
||||
constraints
|
||||
| null (inferredTypeAttrs pOp) = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
| null classConstraints = empty
|
||||
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ concatMap tensorArgConstraint
|
||||
$ inferredTypeAttrs pOp
|
||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||
++ if parsedOpIsMonadic pOp then ["m'"] else []
|
||||
-- Use m' as the type parameter to avoid clashing with an attribute name.
|
||||
monadConstraint
|
||||
| parsedOpIsMonadic pOp = ["MonadBuild m'"]
|
||||
| otherwise = []
|
||||
classConstraints = monadConstraint ++ map tensorArgConstraint
|
||||
(inferredTypeAttrs pOp)
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
|
||||
renderAttrType (AttrSingle a) = renderAttrBaseType a
|
||||
|
@ -295,7 +325,7 @@ typeSig pOp = constraints
|
|||
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
|
||||
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
|
||||
wrapOutput o
|
||||
| parsedOpIsMonadic pOp = "Build" <+> parens o
|
||||
| parsedOpIsMonadic pOp = "m'" <+> parens o
|
||||
| otherwise = o
|
||||
|
||||
-- | Render an op input or output.
|
||||
|
@ -305,17 +335,18 @@ tensorArg p = case parsedArgCase p of
|
|||
ResourceArg -> "ResourceHandle"
|
||||
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
|
||||
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
|
||||
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
|
||||
MixedListArg {argTypeAttr = t, argCaseKind = k}
|
||||
-> "TensorList" <+> kind k <+> renderHaskellName t
|
||||
where
|
||||
kind k = case k of
|
||||
ArgTensorRef -> "Ref"
|
||||
ArgTensorValue -> "Value"
|
||||
ArgTensorEither v' -> strictText v'
|
||||
tensorType t k = let
|
||||
v = case k of
|
||||
ArgTensorRef -> "Tensor Ref"
|
||||
ArgTensorValue -> "Tensor Value"
|
||||
ArgTensorEither v' -> "Tensor" <+> strictText v'
|
||||
a = case t of
|
||||
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
|
||||
ArgTypeAttr n -> renderHaskellName n
|
||||
in v <+> a
|
||||
in "Tensor" <+> kind k <+> a
|
||||
|
||||
attrComment :: Attr a -> Doc
|
||||
attrComment a = argComment' (attrName a) (attrDescription a)
|
||||
|
@ -347,18 +378,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os
|
|||
]
|
||||
|
||||
-- | Constraints for a given type parameter.
|
||||
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
|
||||
tensorArgConstraint :: Attr [DataType] -> [Doc]
|
||||
tensorArgConstraint a
|
||||
= ("TensorType" <+> n
|
||||
: if null typeList
|
||||
then []
|
||||
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
|
||||
-- E.g.: "TensorType t" or "OneOf [Int64, Float] t"
|
||||
-- or "TensorTypes ts" or "OneOfs [..] ts".
|
||||
tensorArgConstraint :: Attr TypeParam -> Doc
|
||||
tensorArgConstraint a = case attrInfo a of
|
||||
TypeParam False Nothing -> "TensorType" <+> n
|
||||
TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n
|
||||
TypeParam True Nothing -> "TensorTypes" <+> n
|
||||
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
|
||||
where
|
||||
n = renderHaskellName $ attrName a
|
||||
typeList = map strictText $
|
||||
Set.toList $ Set.fromList $
|
||||
map dtTypeToHaskell $ attrInfo a
|
||||
n = renderHaskellAttrName a
|
||||
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
|
||||
typeList = ("'" <>) . brackets . commasep . map strictText .
|
||||
Set.toList . Set.fromList .
|
||||
map dtTypeToHaskell . toList
|
||||
|
||||
-- NOTE: The cases of this function should be kept in sync with
|
||||
-- TensorFlow.Types.AllTensorTypes.
|
||||
|
|
|
@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp
|
|||
, Attr(..)
|
||||
, AttrType(..)
|
||||
, AttrBaseType(..)
|
||||
, TypeParam(..)
|
||||
, ParsedArg(..)
|
||||
, ParsedArgCase(..)
|
||||
, ArgType(..)
|
||||
|
@ -62,10 +63,8 @@ data ParsedOp = ParsedOp
|
|||
, explicitInputAttrs :: [Attr AttrType]
|
||||
-- ^ Attributes that must be set explicitly when creating the op.
|
||||
-- Associated with the type of the attribute.
|
||||
, inferredTypeAttrs :: [Attr [DataType]]
|
||||
, inferredTypeAttrs :: [Attr TypeParam]
|
||||
-- ^ Attributes that are type parameters.
|
||||
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
|
||||
-- If this list is empty, then any type is acceptable.
|
||||
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
|
||||
-- Attributes which are list sizes (ints) that are inferred automatically
|
||||
-- from one or more of the input tensors.
|
||||
|
@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
|
|||
| AttrType | AttrShape | AttrTensor
|
||||
deriving Eq
|
||||
|
||||
data TypeParam = TypeParam
|
||||
{ typeParamIsList :: Bool
|
||||
, typeParamRestrictions :: Maybe (NonEmpty DataType)
|
||||
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
|
||||
-- If 'Nothing', then any type is acceptable.
|
||||
}
|
||||
|
||||
-- | An input or output argument (Tensor) for an op.
|
||||
data ParsedArg = ParsedArg
|
||||
{ parsedArgName :: Name
|
||||
|
@ -120,7 +126,6 @@ data ParsedArgCase
|
|||
}
|
||||
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
|
||||
-- ^ A heterogeneous list.
|
||||
-- TODO(judahjacobson): Implement this.
|
||||
| ResourceArg
|
||||
|
||||
argKind :: ParsedArgCase -> Maybe ArgKind
|
||||
|
@ -223,11 +228,6 @@ parseOp o = ParsedOp
|
|||
(o ^. inputArg) tensorKindParams
|
||||
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
|
||||
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
|
||||
-- Type attributes that can be inferred from at least one input or output.
|
||||
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
|
||||
$ parsedInputs ++ parsedOutputs
|
||||
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
|
||||
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
|
||||
-- Integer attributes that can be inferred from the size of at least one
|
||||
-- input list.
|
||||
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
|
||||
|
@ -235,10 +235,14 @@ parseOp o = ParsedOp
|
|||
implicitAttrs = Set.fromList $ map tfName $
|
||||
map attrName inferredTypeAttrs
|
||||
++ map attrName inferredListSizeAttrs
|
||||
-- Attributes that can't be inferred and don't have defaults, so must be passed
|
||||
-- as separate arguments to the op.
|
||||
inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
|
||||
argTypeParams = Set.fromList $ map tfName $
|
||||
mapMaybe (getArgTypeParam . parsedArgCase) $
|
||||
parsedInputs ++ parsedOutputs
|
||||
-- Attributes that can't be inferred and don't have defaults, so must be
|
||||
-- passed as separate arguments to the op.
|
||||
explicitInputAttrs = sortBy (comparing (tfName . attrName))
|
||||
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
|
||||
$ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
|
||||
$ o ^. attr
|
||||
|
||||
-- TODO(judahjacobson): Some arguments should be refs.
|
||||
|
@ -252,29 +256,30 @@ outputTensorKind a
|
|||
| a ^. isRef = ArgTensorRef
|
||||
| otherwise = ArgTensorValue
|
||||
|
||||
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
getExplicitInputAttr implicitAttrs a
|
||||
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
|
||||
getExplicitInputAttr o implicitAttrs a
|
||||
| TFName (a ^. name) `Set.notMember` implicitAttrs
|
||||
, a ^. maybe'defaultValue == Nothing
|
||||
, t <- parseAttrType (a ^. type')
|
||||
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
|
||||
, t <- parseAttrType o (a ^. type')
|
||||
, t `elem` map AttrSingle
|
||||
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
|
||||
++ [AttrList AttrType] = Just t
|
||||
| otherwise = Nothing
|
||||
|
||||
-- | The type attribute used by this input or output (if any).
|
||||
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
|
||||
parsedArgTypeAttr p = case parsedArgCase p of
|
||||
ResourceArg -> Nothing
|
||||
SimpleArg {argType = t} -> fromArgType t
|
||||
ListArg {argType = t} -> fromArgType t
|
||||
MixedListArg {argTypeAttr = n} -> Just $ tfName n
|
||||
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
|
||||
getInferredTypeAttr argTypeParams a
|
||||
| TFName (a ^. name) `notElem` argTypeParams = Nothing
|
||||
| a ^. type' == "type" = Just $ TypeParam False allowed
|
||||
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
|
||||
| otherwise = Nothing
|
||||
where
|
||||
fromArgType (ArgTypeAttr n) = Just $ tfName n
|
||||
fromArgType _ = Nothing
|
||||
allowed = nonEmpty (a ^. allowedValues . list . type')
|
||||
|
||||
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
|
||||
getInferredTypeAttr a
|
||||
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
|
||||
| otherwise = Nothing
|
||||
getArgTypeParam :: ParsedArgCase -> Maybe Name
|
||||
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
|
||||
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
|
||||
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
|
||||
getArgTypeParam _ = Nothing
|
||||
|
||||
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
|
||||
getInferredListSizeAttr inputs a
|
||||
|
@ -285,7 +290,7 @@ getInferredListSizeAttr inputs a
|
|||
} <- inputs
|
||||
, TFName (a ^. name) == tfName n]
|
||||
| otherwise = Nothing
|
||||
|
||||
|
||||
-- | Like mapMaybe, but associates the attribute name/description with the given info.
|
||||
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
|
||||
mapMaybeAttrs f = mapMaybe $ \a -> do
|
||||
|
@ -295,7 +300,7 @@ mapMaybeAttrs f = mapMaybe $ \a -> do
|
|||
, attrDescription = a ^. description
|
||||
, attrInfo = x
|
||||
}
|
||||
|
||||
|
||||
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
|
||||
parseArg a tKind = ParsedArg
|
||||
{ parsedArgName = makeName (a ^. name)
|
||||
|
@ -317,15 +322,15 @@ parseArgCase a tKind
|
|||
maybeAttr "" = Nothing
|
||||
maybeAttr t = Just $ makeName t
|
||||
|
||||
parseAttrType :: Text -> AttrType
|
||||
parseAttrType = \case
|
||||
parseAttrType :: OpDef -> Text -> AttrType
|
||||
parseAttrType o = \case
|
||||
"string" -> AttrSingle AttrBytes
|
||||
"int" -> AttrSingle AttrInt64
|
||||
"float" -> AttrSingle AttrFloat
|
||||
"bool" -> AttrSingle AttrBool
|
||||
"type" -> AttrSingle AttrType
|
||||
"shape" -> AttrSingle AttrShape
|
||||
"tensor" -> AttrSingle AttrTensor
|
||||
"int" -> AttrSingle AttrInt64
|
||||
"float" -> AttrSingle AttrFloat
|
||||
"bool" -> AttrSingle AttrBool
|
||||
"type" -> AttrSingle AttrType
|
||||
"shape" -> AttrSingle AttrShape
|
||||
"tensor" -> AttrSingle AttrTensor
|
||||
"list(string)" -> AttrList AttrBytes
|
||||
"list(int)" -> AttrList AttrInt64
|
||||
"list(float)" -> AttrList AttrFloat
|
||||
|
@ -334,3 +339,4 @@ parseAttrType = \case
|
|||
"list(shape)" -> AttrList AttrShape
|
||||
"list(tensor)" -> AttrList AttrTensor
|
||||
t -> error $ "parseAttrType: unrecognized type " ++ show t
|
||||
++ " for op " ++ show (o ^. name)
|
||||
|
|
|
@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
|
|||
|
||||
import Control.Monad (zipWithM)
|
||||
import Data.Int (Int32, Int64)
|
||||
import TensorFlow.Build (Build, colocateWith, render)
|
||||
import TensorFlow.Build (MonadBuild, colocateWith, render)
|
||||
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
|
||||
import TensorFlow.Tensor (Tensor, Value)
|
||||
import TensorFlow.Types (OneOf, TensorType)
|
||||
|
@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
|
|||
--
|
||||
-- The results of the lookup are concatenated into a dense
|
||||
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup :: forall a b v .
|
||||
( TensorType a
|
||||
embeddingLookup :: forall a b v m .
|
||||
( MonadBuild m
|
||||
, TensorType a
|
||||
, OneOf '[Int64, Int32] b
|
||||
, Num b
|
||||
)
|
||||
|
@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
|
|||
-- containing the ids to be looked up in `params`.
|
||||
-- The ids are required to have fewer than 2^31
|
||||
-- entries.
|
||||
-> Build (Tensor Value a)
|
||||
-> m (Tensor Value a)
|
||||
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
|
||||
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
|
||||
embeddingLookup params@(p0 : _) ids = do
|
||||
|
|
|
@ -56,7 +56,9 @@ import qualified Data.Text as Text
|
|||
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Build
|
||||
( Build
|
||||
( MonadBuild
|
||||
, Build
|
||||
, build
|
||||
, render
|
||||
, renderNodeName
|
||||
, renderedNodeDefs
|
||||
|
@ -70,6 +72,7 @@ import TensorFlow.Ops
|
|||
, expandDims
|
||||
, fill
|
||||
, matMul
|
||||
, matMul'
|
||||
, reducedShape
|
||||
, reluGrad
|
||||
, reshape
|
||||
|
@ -93,7 +96,6 @@ import TensorFlow.Tensor
|
|||
, TensorKind (ValueKind)
|
||||
, Value
|
||||
, tensorOutput
|
||||
, tensorAttr
|
||||
)
|
||||
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
|
||||
import Proto.Tensorflow.Core.Framework.NodeDef
|
||||
|
@ -111,16 +113,17 @@ type GradientCompatible a =
|
|||
|
||||
|
||||
-- | Gradient of @y@ w.r.t. each element of @xs@.
|
||||
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
|
||||
gradients :: forall a v1 v2 m . (MonadBuild m
|
||||
, Num (Tensor v1 a)
|
||||
-- TODO(gnezdo): remove indirect constraint.
|
||||
-- It's a wart inherited from Num instance.
|
||||
-- It's a wart inherited from Num instance.
|
||||
, v1 ~ Value
|
||||
, GradientCompatible a
|
||||
)
|
||||
=> Tensor v1 a -- ^ The output of the graph.
|
||||
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
|
||||
-> Build [Tensor Value a]
|
||||
gradients y xs = do
|
||||
-> m [Tensor Value a]
|
||||
gradients y xs = build $ do
|
||||
-- The gradients are computed using "reverse accumulation", similarly to
|
||||
-- what is described here:
|
||||
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
|
||||
|
@ -529,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
let transposeA = lookupAttr nodeDef "transpose_a"
|
||||
transposeB = lookupAttr nodeDef "transpose_b"
|
||||
transAttrs a b =
|
||||
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
|
||||
(opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
|
||||
in case (transposeA, transposeB) of
|
||||
(False, False) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||
[ Just $ matMul' (transAttrs False True) dz y
|
||||
, Just $ matMul' (transAttrs True False) x dz]
|
||||
(False, True) ->
|
||||
[ Just $ dz `matMul` y
|
||||
, Just $ (x `matMul` dz) & transAttrs True False ]
|
||||
[ Just $ matMul dz y
|
||||
, Just $ matMul' (transAttrs True False) x dz]
|
||||
(True, False) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs False True
|
||||
, Just $ x `matMul` dz ]
|
||||
[ Just $ matMul' (transAttrs False True) dz y
|
||||
, Just $ matMul x dz]
|
||||
(True, True) ->
|
||||
[ Just $ (dz `matMul` y) & transAttrs True True
|
||||
, Just $ (x `matMul` dz) & transAttrs True True ]
|
||||
[ Just $ matMul' (transAttrs True True) dz y
|
||||
, Just $ matMul' (transAttrs True True) x dz]
|
||||
|
||||
opGrad "Transpose" _ [_, toT -> p] [dz] =
|
||||
[ Just $ CoreOps.transpose dz
|
||||
|
@ -551,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
|
|||
]
|
||||
|
||||
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
||||
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
[ Just $ CoreOps.conv2DBackpropInput'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
(shape x) y dz
|
||||
, Just $ CoreOps.conv2DBackpropFilter'
|
||||
((opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x (shape y) dz
|
||||
]
|
||||
where
|
||||
strides = lookupAttr nodeDef "strides" :: [Int64]
|
||||
|
@ -569,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
|
|||
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
|
||||
|
||||
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
|
||||
[ Just $ CoreOps.maxPoolGrad x output dz
|
||||
& tensorAttr "ksize" .~ ksize
|
||||
& tensorAttr "strides" .~ strides
|
||||
& tensorAttr "padding" .~ padding
|
||||
& tensorAttr "data_format" .~ dataFormat
|
||||
[ Just $ CoreOps.maxPoolGrad'
|
||||
((opAttr "ksize" .~ ksize)
|
||||
. (opAttr "strides" .~ strides)
|
||||
. (opAttr "padding" .~ padding)
|
||||
. (opAttr "data_format" .~ dataFormat))
|
||||
x output dz
|
||||
]
|
||||
where
|
||||
output :: Tensor Value a
|
||||
|
|
|
@ -58,56 +58,99 @@
|
|||
|
||||
module TensorFlow.Ops
|
||||
( CoreOps.add
|
||||
, CoreOps.add'
|
||||
, CoreOps.abs
|
||||
, CoreOps.abs'
|
||||
, CoreOps.addN
|
||||
, CoreOps.addN'
|
||||
, CoreOps.argMax
|
||||
, CoreOps.argMax'
|
||||
, CoreOps.assign
|
||||
, CoreOps.assign'
|
||||
, CoreOps.broadcastGradientArgs
|
||||
, CoreOps.broadcastGradientArgs'
|
||||
, CoreOps.cast
|
||||
, CoreOps.cast'
|
||||
, CoreOps.concat
|
||||
, CoreOps.concat'
|
||||
, constant
|
||||
, constant'
|
||||
, CoreOps.equal
|
||||
, CoreOps.equal'
|
||||
, expandDims
|
||||
, expandDims'
|
||||
, initializedVariable
|
||||
, initializedVariable'
|
||||
, zeroInitializedVariable
|
||||
, zeroInitializedVariable'
|
||||
, CoreOps.fill
|
||||
, CoreOps.oneHot
|
||||
, CoreOps.fill'
|
||||
, CoreOps.identity
|
||||
, CoreOps.identity'
|
||||
, CoreOps.matMul
|
||||
, CoreOps.matMul'
|
||||
, matTranspose
|
||||
, matTranspose'
|
||||
, CoreOps.mean
|
||||
, CoreOps.mean'
|
||||
, CoreOps.mul
|
||||
, CoreOps.mul'
|
||||
, CoreOps.neg
|
||||
, CoreOps.neg'
|
||||
, CoreOps.oneHot
|
||||
, CoreOps.oneHot'
|
||||
, CoreOps.pack
|
||||
, CoreOps.pack'
|
||||
, placeholder
|
||||
, placeholder'
|
||||
, CoreOps.range
|
||||
, CoreOps.range'
|
||||
, reducedShape
|
||||
, CoreOps.relu
|
||||
, CoreOps.relu'
|
||||
, CoreOps.reluGrad
|
||||
, CoreOps.reluGrad'
|
||||
, CoreOps.reshape
|
||||
, CoreOps.reshape'
|
||||
, restore
|
||||
, restoreFromName
|
||||
, save
|
||||
, scalar
|
||||
, scalar'
|
||||
, shape
|
||||
, shape'
|
||||
, CoreOps.sign
|
||||
, CoreOps.sign'
|
||||
, CoreOps.size
|
||||
, CoreOps.size'
|
||||
, CoreOps.softmax
|
||||
, CoreOps.softmax'
|
||||
, CoreOps.softmaxCrossEntropyWithLogits
|
||||
, CoreOps.softmaxCrossEntropyWithLogits'
|
||||
, CoreOps.sparseToDense
|
||||
, CoreOps.sparseToDense'
|
||||
, CoreOps.sub
|
||||
, CoreOps.sub'
|
||||
, CoreOps.sum
|
||||
, CoreOps.sum'
|
||||
, CoreOps.transpose
|
||||
, CoreOps.transpose'
|
||||
, truncatedNormal
|
||||
, truncatedNormal'
|
||||
, CoreOps.variable
|
||||
, CoreOps.variable'
|
||||
, vector
|
||||
, vector'
|
||||
, zeros
|
||||
, CoreOps.zerosLike
|
||||
, CoreOps.zerosLike'
|
||||
, scalarize
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Complex (Complex)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Data.Word (Word16)
|
||||
import Prelude hiding (abs, sum, concat)
|
||||
import Data.ProtoLens (def)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
|
@ -151,60 +194,73 @@ instance ( TensorType a
|
|||
signum = CoreOps.sign
|
||||
negate = CoreOps.neg
|
||||
|
||||
matTranspose :: forall a v . TensorType a
|
||||
=> Tensor v a -> Tensor Value a
|
||||
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
|
||||
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
|
||||
matTranspose = matTranspose' id
|
||||
|
||||
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
|
||||
placeholder shape' =
|
||||
buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ shape'
|
||||
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
|
||||
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
||||
|
||||
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||
placeholder = placeholder' id
|
||||
|
||||
placeholder' :: forall m a . (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Shape -> m (Tensor Value a)
|
||||
placeholder' params pShape
|
||||
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
|
||||
-- and thus would be CSE'd.
|
||||
= build $ buildOp $ opDef "Placeholder"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "shape" .~ pShape
|
||||
& params
|
||||
|
||||
-- | Creates a variable initialized to the given value.
|
||||
-- Initialization happens next time session runs.
|
||||
initializedVariable :: forall a . TensorType a
|
||||
=> Tensor Value a -> Build (Tensor Ref a)
|
||||
initializedVariable initializer = do
|
||||
v <- CoreOps.variable [] -- The shape is not known initially.
|
||||
(i :: Tensor Ref a) <-
|
||||
buildOp (opDef "Assign"
|
||||
& opAttr "T" .~ tensorType (undefined :: a)
|
||||
& opAttr "use_locking" .~ True
|
||||
& opAttr "validate_shape" .~ False
|
||||
)
|
||||
v initializer
|
||||
initializedVariable :: (MonadBuild m, TensorType a)
|
||||
=> Tensor Value a -> m (Tensor Ref a)
|
||||
initializedVariable = initializedVariable' id
|
||||
|
||||
initializedVariable' :: (MonadBuild m, TensorType a)
|
||||
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
|
||||
initializedVariable' params initializer = do
|
||||
v <- CoreOps.variable' params [] -- The shape is not known initially.
|
||||
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
|
||||
initializer
|
||||
addInitializer =<< group i
|
||||
return v
|
||||
|
||||
-- | Creates a zero-initialized variable with the given shape.
|
||||
zeroInitializedVariable
|
||||
:: (TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable = initializedVariable . zeros
|
||||
:: (MonadBuild m, TensorType a, Num a) =>
|
||||
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable = zeroInitializedVariable' id
|
||||
|
||||
zeroInitializedVariable'
|
||||
:: (MonadBuild m, TensorType a, Num a) =>
|
||||
OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
|
||||
zeroInitializedVariable' params = initializedVariable' params . zeros
|
||||
|
||||
-- TODO: Support heterogeneous list of tensors.
|
||||
save :: forall a v . TensorType a
|
||||
save :: forall a m v . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> [Tensor v a] -- ^ Tensors to save.
|
||||
-> Build ControlNode
|
||||
-> m ControlNode
|
||||
save path xs = do
|
||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
||||
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
|
||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||
let saveOp = buildOp $ opDef "Save"
|
||||
& opAttr "T" .~ types
|
||||
saveOp (scalar path) (CoreOps.pack names) xs
|
||||
build $ saveOp (scalar path) (CoreOps.pack names) xs
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
--
|
||||
-- This version allows restoring from a checkpoint file that uses a different
|
||||
-- tensor name than the variable.
|
||||
restoreFromName :: forall a . TensorType a
|
||||
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> ByteString -- ^ Tensor name override.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
-> m ControlNode
|
||||
restoreFromName path name x = do
|
||||
let restoreOp = buildOp $ opDef "Restore"
|
||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||
|
@ -212,12 +268,12 @@ restoreFromName path name x = do
|
|||
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
||||
|
||||
-- | Restore a tensor's value from a checkpoint file.
|
||||
restore :: forall a . TensorType a
|
||||
restore :: forall a m . (MonadBuild m, TensorType a)
|
||||
=> ByteString -- ^ File path.
|
||||
-> Tensor Ref a -- ^ Tensor to restore.
|
||||
-> Build ControlNode
|
||||
-> m ControlNode
|
||||
restore path x = do
|
||||
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
||||
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
|
||||
restoreFromName path name x
|
||||
|
||||
-- | Create a constant tensor.
|
||||
|
@ -227,27 +283,27 @@ restore path x = do
|
|||
-- element 0: index (0, ..., 0)
|
||||
-- element 1: index (0, ..., 1)
|
||||
-- ...
|
||||
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant (Shape shape') values
|
||||
constant :: TensorType a => Shape -> [a] -> Tensor Value a
|
||||
constant = constant' id
|
||||
|
||||
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
|
||||
constant' params (Shape cShape) values
|
||||
| invalidLength = error invalidLengthMsg
|
||||
| otherwise = buildOp $ opDef "Const"
|
||||
& opAttr "value" .~ typedNode
|
||||
& opAttr "dtype" .~ nodeType
|
||||
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
|
||||
where
|
||||
invalidLength = product shape' /= fromIntegral (length values)
|
||||
invalidLength = product cShape /= fromIntegral (length values)
|
||||
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
|
||||
(product shape')
|
||||
(product cShape)
|
||||
(length values)
|
||||
nodeType = tensorType (undefined :: a)
|
||||
typedNode :: TensorProto
|
||||
typedNode = def
|
||||
& dtype .~ nodeType
|
||||
& dtype .~ tensorType (undefined :: a)
|
||||
& tensorShape.TensorShape.dim .~
|
||||
[def & TensorShape.size .~ x | x <- shape']
|
||||
[def & TensorShape.size .~ x | x <- cShape]
|
||||
& tensorVal .~ values
|
||||
|
||||
-- | Reshape a N-D tensor down to a scalar.
|
||||
--
|
||||
--
|
||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||
|
@ -257,29 +313,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
|
|||
|
||||
-- | Create a constant vector.
|
||||
vector :: TensorType a => [a] -> Tensor Value a
|
||||
vector xs = constant [fromIntegral $ length xs] xs
|
||||
vector = vector' id
|
||||
|
||||
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
|
||||
vector' params xs = constant' params [fromIntegral $ length xs] xs
|
||||
|
||||
-- | Create a constant scalar.
|
||||
scalar :: forall a . TensorType a => a -> Tensor Value a
|
||||
scalar x = constant [] [x]
|
||||
scalar :: TensorType a => a -> Tensor Value a
|
||||
scalar = scalar' id
|
||||
|
||||
-- Random tensor from the unit normal distribution with bounded values.
|
||||
truncatedNormal :: forall a v . TensorType a
|
||||
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
|
||||
scalar' params x = constant' params [] [x]
|
||||
|
||||
-- | Random tensor from the unit normal distribution with bounded values.
|
||||
--
|
||||
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
|
||||
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||
=> Tensor v Int64 -- ^ Shape.
|
||||
-> Build (Tensor Value a)
|
||||
truncatedNormal = buildOp $ opDef "TruncatedNormal"
|
||||
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||
& opAttr "T" .~ tensorType (undefined :: Int64)
|
||||
-> m (Tensor Value a)
|
||||
truncatedNormal = CoreOps.truncatedNormal
|
||||
|
||||
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
|
||||
=> OpParams -> Tensor v Int64 -- ^ Shape.
|
||||
-> m (Tensor Value a)
|
||||
truncatedNormal' = CoreOps.truncatedNormal'
|
||||
|
||||
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
|
||||
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
|
||||
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
|
||||
|
||||
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
|
||||
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
|
||||
shape = CoreOps.shape
|
||||
|
||||
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
|
||||
shape' = CoreOps.shape'
|
||||
|
||||
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims = CoreOps.expandDims
|
||||
|
||||
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
|
||||
expandDims' = CoreOps.expandDims'
|
||||
|
||||
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
|
||||
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
|
||||
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
|
||||
|
|
|
@ -19,8 +19,7 @@
|
|||
module Main where
|
||||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Lens.Family2 ((^.))
|
||||
import Lens.Family2 ((^.), (.~))
|
||||
import Data.List (sort)
|
||||
import Proto.Tensorflow.Core.Framework.Graph
|
||||
( node )
|
||||
|
@ -35,13 +34,12 @@ import TensorFlow.Build
|
|||
, asGraphDef
|
||||
, evalBuildT
|
||||
, flushNodeBuffer
|
||||
, hoistBuildT
|
||||
, render
|
||||
, withDevice
|
||||
, colocateWith
|
||||
, withNameScope
|
||||
, opName
|
||||
)
|
||||
import TensorFlow.ControlFlow (named)
|
||||
import TensorFlow.Types (unScalar)
|
||||
import TensorFlow.Ops
|
||||
( add
|
||||
|
@ -49,13 +47,12 @@ import TensorFlow.Ops
|
|||
, constant
|
||||
, initializedVariable
|
||||
, variable
|
||||
, variable'
|
||||
)
|
||||
import TensorFlow.Output (Device(..))
|
||||
import TensorFlow.Tensor (Tensor, Value, Ref)
|
||||
import TensorFlow.Session
|
||||
( build
|
||||
, buildAnd
|
||||
, run
|
||||
( run
|
||||
, runSession
|
||||
, run_
|
||||
)
|
||||
|
@ -65,26 +62,16 @@ import Test.HUnit ((@=?))
|
|||
import Google.Test (googleTest)
|
||||
import qualified Data.Vector as V
|
||||
|
||||
-- | Test named behavior.
|
||||
testNamed :: Test
|
||||
testNamed = testCase "testNamed" $ do
|
||||
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
|
||||
-- | Test 'opName' behavior.
|
||||
testOpName :: Test
|
||||
testOpName = testCase "testOpName" $ do
|
||||
let graph = variable' (opName .~ "foo") []
|
||||
>>= render :: Build (Tensor Ref Float)
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
"foo" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Test named deRef behavior.
|
||||
testNamedDeRef :: Test
|
||||
testNamedDeRef = testCase "testNamedDeRef" $ do
|
||||
let graph = named "foo" <$> do
|
||||
v :: Tensor Ref Float <- variable []
|
||||
assign v 5
|
||||
-- TODO: Implement TensorFlow get_variable and test it.
|
||||
runSession $ do
|
||||
out <- buildAnd run graph
|
||||
liftIO $ 5 @=? (unScalar out :: Float)
|
||||
|
||||
-- | Test that "run" will render and extend any pure ops that haven't already
|
||||
-- been rendered.
|
||||
testPureRender :: Test
|
||||
|
@ -96,7 +83,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
|
|||
testInitializedVariable :: Test
|
||||
testInitializedVariable =
|
||||
testCase "testInitializedVariable" $ runSession $ do
|
||||
(formula, reset) <- build $ do
|
||||
(formula, reset) <- do
|
||||
v <- initializedVariable 42
|
||||
r <- assign v 24
|
||||
return (1 `add` v, r)
|
||||
|
@ -109,7 +96,7 @@ testInitializedVariable =
|
|||
testInitializedVariableShape :: Test
|
||||
testInitializedVariableShape =
|
||||
testCase "testInitializedVariableShape" $ runSession $ do
|
||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||
vector <- initializedVariable (constant [1] [42 :: Float])
|
||||
result <- run vector
|
||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||
|
||||
|
@ -122,33 +109,30 @@ testNameScoped = testCase "testNameScoped" $ do
|
|||
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
|
||||
-- | Test combined named and nameScoped behavior.
|
||||
-- | Test combined opName and nameScoped behavior.
|
||||
testNamedAndScoped :: Test
|
||||
testNamedAndScoped = testCase "testNamedAndScoped" $ do
|
||||
let graph :: Build (Tensor Ref Float)
|
||||
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
|
||||
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
|
||||
>>= render
|
||||
nodeDef :: NodeDef
|
||||
nodeDef = head $ asGraphDef graph ^. node
|
||||
"RefIdentity" @=? (nodeDef ^. op)
|
||||
"Variable" @=? (nodeDef ^. op)
|
||||
"foo1/bar1" @=? (nodeDef ^. name)
|
||||
|
||||
-- | Lift a Build action into a context for HUnit to run.
|
||||
liftBuild :: Build a -> BuildT IO a
|
||||
liftBuild = hoistBuildT (return . runIdentity)
|
||||
|
||||
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
|
||||
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
|
||||
flushed field = sort . map field <$> liftBuild flushNodeBuffer
|
||||
flushed field = sort . map field <$> flushNodeBuffer
|
||||
|
||||
-- | Test the interaction of rendering, CSE and scoping.
|
||||
testRenderDedup :: Test
|
||||
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
||||
liftBuild renderNodes
|
||||
renderNodes
|
||||
names <- flushed (^. name)
|
||||
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
|
||||
-- Render the nodes in a different scope, which should cause them
|
||||
-- to be distinct from the previous ones.
|
||||
liftBuild $ withNameScope "foo" renderNodes
|
||||
withNameScope "foo" renderNodes
|
||||
scopedNames <- flushed (^. name)
|
||||
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
|
||||
where
|
||||
|
@ -165,7 +149,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
|
|||
-- | Test the interaction of rendering, CSE and scoping.
|
||||
testDeviceColocation :: Test
|
||||
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
|
||||
liftBuild renderNodes
|
||||
renderNodes
|
||||
devices <- flushed (\x -> (x ^. name, x ^. device))
|
||||
liftIO $ [ ("Add_2","dev0")
|
||||
, ("Const_1","dev0")
|
||||
|
@ -182,8 +166,7 @@ main :: IO ()
|
|||
main = googleTest [ testInitializedVariable
|
||||
, testInitializedVariableShape
|
||||
, testDeviceColocation
|
||||
, testNamed
|
||||
, testNamedDeRef
|
||||
, testOpName
|
||||
, testNameScoped
|
||||
, testNamedAndScoped
|
||||
, testPureRender
|
||||
|
|
|
@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
|
|||
restitch = CoreOps.dynamicStitch restitchIndices splitParts
|
||||
in monadicIO $ run $ do
|
||||
fromIntegral numParts @=? length splitParts
|
||||
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
|
||||
valuesOut <- TF.runSession $ TF.run restitch
|
||||
V.fromList values @=? valuesOut
|
||||
|
||||
data StitchExample a = StitchExample Int64 [a] [Int32]
|
||||
|
|
|
@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF
|
|||
import qualified TensorFlow.Types as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
|
||||
|
||||
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
|
||||
buildAndRun = TF.runSession . TF.buildAnd TF.run
|
||||
|
||||
|
||||
-- | Tries to perform a simple embedding lookup, with two partitions.
|
||||
|
@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition =
|
|||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup embedding ids
|
||||
|
||||
(values, shape) <- buildAndRun $ do
|
||||
(values, shape) <- TF.runSession $ do
|
||||
vs <- op
|
||||
return (vs, TF.shape vs)
|
||||
TF.run (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
shape @=? V.fromList [1, 2, 3]
|
||||
|
@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape =
|
|||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||
let op = embeddingLookup [embedding] ids
|
||||
|
||||
(values, shape) <- buildAndRun $ do
|
||||
(values, shape) <- TF.runSession $ do
|
||||
vs <- op
|
||||
return (vs, TF.shape vs)
|
||||
TF.run (vs, TF.shape vs)
|
||||
|
||||
-- This is the shape that is returned in the equiv. Python.
|
||||
shape @=? V.fromList [1, 2, 3]
|
||||
|
@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
let shape = TF.Shape [2]
|
||||
|
||||
gs <- TF.runSession $ do
|
||||
grads <- TF.build $ do
|
||||
let embShape = TF.Shape [2, 1]
|
||||
let embeddingInit = [1, 20 ::Float]
|
||||
let idValues = [1, 1 :: Int32]
|
||||
|
@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
|||
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
|
||||
|
||||
grad <- fmap head (TF.gradients loss [embedding])
|
||||
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
|
||||
|
||||
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
|
||||
TF.runWithFeeds
|
||||
[TF.feed x $ TF.encodeTensorData shape xVals]
|
||||
grad
|
||||
-- Gradients should be zero (or close)
|
||||
assertAllClose gs (V.fromList ([0, 0 :: Float]))
|
||||
|
||||
|
@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit
|
|||
shapedValues = TF.constant shape values
|
||||
in monadicIO $ run $ do
|
||||
(shapeOut, got, want :: V.Vector a) <-
|
||||
TF.runSession $ TF.buildAnd TF.run $ do
|
||||
TF.runSession $ TF.run =<< do
|
||||
embeddings <- embeddingLookup modShardedValues indicesVector
|
||||
return (TF.cast (TF.shape embeddings), embeddings, directs)
|
||||
-- Checks the explicitly documented invariant of embeddingLookup.
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE NoMonomorphismRestriction #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
import Data.Int (Int32)
|
||||
|
@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do
|
|||
y = x*x + b
|
||||
grads = TF.gradients y [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||
6 @=? TF.unScalar dx
|
||||
1 @=? TF.unScalar db
|
||||
-- Assert that the graph has the expected ops.
|
||||
|
@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
|||
b = TF.scalar (4 :: Float)
|
||||
grads = TF.gradients x [x, b]
|
||||
-- Assert that the gradients are right.
|
||||
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
|
||||
[dx, db] <- TF.runSession $ grads >>= TF.run
|
||||
1 @=? TF.unScalar dx
|
||||
0 @=? TF.unScalar db
|
||||
-- Assert that the graph has the expected ops.
|
||||
|
@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
|
|||
-- Test that identical "stateful" ops work with createGraph.
|
||||
testCreateGraphStateful :: Test
|
||||
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
||||
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx, dy] <- TF.runSession $ do
|
||||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
|
||||
TF.gradients (x + y*3) [x, y]
|
||||
TF.gradients (x + y*3) [x, y] >>= TF.run
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. These asserts are extra.
|
||||
1 @=? TF.unScalar dx
|
||||
|
@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
|
|||
-- Test that name scopes work with createGraph.
|
||||
testCreateGraphNameScopes :: Test
|
||||
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let shape = TF.constant (TF.Shape [1]) [1]
|
||||
x :: TF.Tensor TF.Value Float <-
|
||||
TF.withNameScope "foo" (TF.truncatedNormal shape)
|
||||
TF.gradients x [x]
|
||||
TF.gradients x [x] >>= TF.run
|
||||
-- If this test fails, it will likely be caused by an exception within
|
||||
-- `TF.gradients`. This assert is extra.
|
||||
1 @=? TF.unScalar dx
|
||||
|
@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
|
|||
-- Test that createGraph can handle graphs with diamond shapes.
|
||||
testDiamond :: Test
|
||||
testDiamond = testCase "testDiamond" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let x = TF.vector [1]
|
||||
y = x*x
|
||||
z = y*y
|
||||
TF.gradients z [x]
|
||||
TF.gradients z [x] >>= TF.run
|
||||
(4 :: Float) @=? TF.unScalar dx
|
||||
|
||||
|
||||
testMaxGradient :: Test
|
||||
testMaxGradient = testCase "testMaxGradient" $ do
|
||||
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
|
||||
[dx] <- TF.runSession $ do
|
||||
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
|
||||
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
|
||||
TF.gradients y [x]
|
||||
TF.gradients y [x] >>= TF.run
|
||||
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Main where
|
||||
|
@ -19,6 +20,7 @@ module Main where
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int32, Int64)
|
||||
import Google.Test (googleTest)
|
||||
import Lens.Family2 ((.~))
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
import Test.Framework (Test)
|
||||
import Test.Framework.Providers.HUnit (testCase)
|
||||
|
@ -27,7 +29,6 @@ import qualified Data.ByteString.Char8 as B8
|
|||
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
|
@ -41,7 +42,7 @@ testSize = testCase "testSize" $ do
|
|||
TF.Scalar (2 * 3 :: Int32) @=? x
|
||||
|
||||
eval :: TF.Fetchable t a => t -> IO a
|
||||
eval = TF.runSession . TF.buildAnd TF.run . return
|
||||
eval = TF.runSession . TF.run
|
||||
|
||||
-- | Confirms that the original example from Python code works.
|
||||
testReducedShape :: Test
|
||||
|
@ -54,22 +55,48 @@ testSaveRestore :: Test
|
|||
testSaveRestore = testCase "testSaveRestore" $
|
||||
withSystemTempDirectory "" $ \dirPath -> do
|
||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
||||
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
|
||||
var = TF.render =<<
|
||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
||||
TF.zeroInitializedVariable' (TF.opName .~ "a")
|
||||
(TF.Shape [])
|
||||
TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
||||
TF.buildAnd TF.run_ $ TF.save path [v]
|
||||
v <- var
|
||||
TF.assign v 134 >>= TF.run_
|
||||
TF.save path [v] >>= TF.run_
|
||||
result <- TF.runSession $ do
|
||||
v <- TF.build var
|
||||
TF.buildAnd TF.run_ $ TF.restore path v
|
||||
v <- var
|
||||
TF.restore path v >>= TF.run_
|
||||
TF.run v
|
||||
liftIO $ TF.Scalar 134 @=? result
|
||||
|
||||
-- | Test that 'placeholder' is not CSE'd.
|
||||
testPlaceholderCse :: Test
|
||||
testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
||||
p1 <- TF.placeholder []
|
||||
p2 <- TF.placeholder []
|
||||
let enc :: Float -> TF.TensorData Float
|
||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
-- | Test that regular tensors can also be used for feeds, as long as they each
|
||||
-- have a different name.
|
||||
testScalarFeedCse :: Test
|
||||
testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
|
||||
p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0
|
||||
-- The second op is identical to the first other than its name; make sure
|
||||
-- we don't aggressively CSE them together and prevent feeding them
|
||||
-- separately.
|
||||
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
|
||||
let enc :: Float -> TF.TensorData Float
|
||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testSaveRestore
|
||||
, testSize
|
||||
, testReducedShape
|
||||
, testPlaceholderCse
|
||||
, testScalarFeedCse
|
||||
]
|
||||
|
|
|
@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do
|
|||
let x = TF.vector xData
|
||||
y = TF.vector yData
|
||||
-- Create scalar variables for slope and intercept.
|
||||
w <- TF.build (TF.initializedVariable 0)
|
||||
b <- TF.build (TF.initializedVariable 0)
|
||||
w <- TF.initializedVariable 0
|
||||
b <- TF.initializedVariable 0
|
||||
-- Define the loss function.
|
||||
let yHat = (x `TF.mul` w) `TF.add` b
|
||||
loss = TF.square (yHat `TF.sub` y)
|
||||
-- Optimize with gradient descent.
|
||||
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
|
||||
trainStep <- gradientDescent 0.001 loss [w, b]
|
||||
replicateM_ 1000 (TF.run trainStep)
|
||||
-- Return the learned parameters.
|
||||
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
|
||||
|
@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do
|
|||
gradientDescent :: Float
|
||||
-> TF.Tensor TF.Value Float
|
||||
-> [TF.Tensor TF.Ref Float]
|
||||
-> TF.Build TF.ControlNode
|
||||
-> TF.Session TF.ControlNode
|
||||
gradientDescent alpha loss params = do
|
||||
let applyGrad param grad =
|
||||
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
|
||||
|
|
|
@ -35,7 +35,7 @@ testTracing = do
|
|||
loggedValue <- newEmptyMVar
|
||||
TF.runSessionWithOptions
|
||||
(def & TF.sessionTracer .~ putMVar loggedValue)
|
||||
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
|
||||
(TF.run_ (TF.scalar (0 :: Float)))
|
||||
tryReadMVar loggedValue >>=
|
||||
maybe (assertFailure "Logging never happened") expectedFormat
|
||||
where expectedFormat x =
|
||||
|
|
|
@ -36,7 +36,6 @@ import qualified Data.ByteString as B
|
|||
import qualified Data.ByteString.Char8 as B8
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (select)
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
|
|
|
@ -12,67 +12,60 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
-- | Queues in TensorFlow graph. Very limited support for now.
|
||||
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
|
||||
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Lens.Family2 ((.~), (&))
|
||||
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
|
||||
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
|
||||
import TensorFlow.BuildOp (buildOp)
|
||||
import TensorFlow.ControlFlow (group)
|
||||
import TensorFlow.Tensor (Ref, Tensor)
|
||||
import TensorFlow.Types (TensorType, tensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
|
||||
import TensorFlow.Types (TensorTypes, fromTensorTypes)
|
||||
|
||||
-- | A queue carrying tuples. The underlying structure is more
|
||||
-- versatile and can be made to support arbitrary tuples.
|
||||
data Queue2 a b = Queue2 { handle :: Handle }
|
||||
-- | A queue carrying tuples.
|
||||
data Queue (as :: [*]) = Queue { handle :: Handle }
|
||||
|
||||
type Handle = Tensor Ref ByteString
|
||||
|
||||
-- | Adds the given values to the queue.
|
||||
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
|
||||
=> Queue2 a b
|
||||
-> Tensor v1 a
|
||||
-> Tensor v2 b
|
||||
-> Build ControlNode
|
||||
enqueue q =
|
||||
buildOp (opDef "QueueEnqueue"
|
||||
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)])
|
||||
(handle q)
|
||||
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
|
||||
=> Queue as
|
||||
-> TensorList v as
|
||||
-> m ControlNode
|
||||
enqueue = CoreOps.queueEnqueue . handle
|
||||
|
||||
-- | Retrieves the values from the queue.
|
||||
dequeue :: forall a b . (TensorType a, TensorType b)
|
||||
=> Queue2 a b
|
||||
-> Build (Tensor Ref a, Tensor Ref b)
|
||||
-- ^ Dequeued tensors. They are paired in a sense
|
||||
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||
=> Queue as
|
||||
-> m (TensorList Value as)
|
||||
-- ^ Dequeued tensors. They are coupled in a sense
|
||||
-- that values appear together, even if they are
|
||||
-- not consumed together.
|
||||
dequeue q =
|
||||
buildOp (opDef "QueueDequeue"
|
||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)])
|
||||
(handle q)
|
||||
dequeue = CoreOps.queueDequeue . handle
|
||||
|
||||
-- | Creates a new queue with the given capacity and shared name.
|
||||
makeQueue2 :: forall a b . (TensorType a, TensorType b)
|
||||
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
|
||||
=> Int64 -- ^ The upper bound on the number of elements in
|
||||
-- this queue. Negative numbers mean no limit.
|
||||
-> ByteString -- ^ If non-empty, this queue will be shared
|
||||
-- under the given name across multiple sessions.
|
||||
-> Build (Queue2 a b)
|
||||
makeQueue2 capacity sharedName = do
|
||||
q <- buildOp (opDef "FIFOQueue"
|
||||
& opAttr "component_types" .~ [ tensorType (undefined :: a)
|
||||
, tensorType (undefined :: b)]
|
||||
-> m (Queue as)
|
||||
makeQueue capacity sharedName = do
|
||||
q <- build $ buildOp (opDef "FIFOQueue"
|
||||
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
|
||||
& opAttr "shared_name" .~ sharedName
|
||||
& opAttr "capacity" .~ capacity
|
||||
)
|
||||
group q >>= addInitializer
|
||||
return (Queue2 q)
|
||||
return (Queue q)
|
||||
|
||||
-- TODO(gnezdo): Figure out the closing story for queues.
|
||||
|
|
|
@ -39,6 +39,7 @@ Test-Suite QueueTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-queue
|
||||
, test-framework
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
|
@ -20,13 +21,12 @@ module Main where
|
|||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Types (Scalar(..))
|
||||
import TensorFlow.Types (ListOf(..), Scalar(..), (/:/))
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Queue
|
||||
import TensorFlow.Session
|
||||
( asyncProdNodes
|
||||
, build
|
||||
, buildAnd
|
||||
, run
|
||||
, runSession
|
||||
, run_
|
||||
|
@ -39,42 +39,50 @@ import qualified Data.ByteString as BS
|
|||
-- | Test basic queue behaviors.
|
||||
testBasic :: Test
|
||||
testBasic = testCase "testBasic" $ runSession $ do
|
||||
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
|
||||
buildAnd run_ (enqueue q 42 (scalar "Hi"))
|
||||
x <- buildAnd run (dequeue q)
|
||||
liftIO $ (Scalar 42, Scalar "Hi") @=? x
|
||||
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
||||
run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)
|
||||
x <- run =<< dequeue q
|
||||
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
|
||||
|
||||
buildAnd run_ (enqueue q 56 (scalar "Bar"))
|
||||
y <- buildAnd run (dequeue q)
|
||||
liftIO $ (Scalar 56, Scalar "Bar") @=? y
|
||||
run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil)
|
||||
y <- run =<< dequeue q
|
||||
-- Note: we use explicit "Scalar" here to specify the type that was
|
||||
-- fetched. Equivalently we could write
|
||||
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
||||
-- or else allow the types to be determined by future use of the fetched
|
||||
-- value.
|
||||
let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil
|
||||
liftIO $ expected @=? y
|
||||
|
||||
-- | Test queue pumping.
|
||||
testPump :: Test
|
||||
testPump = testCase "testPump" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
|
||||
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue"
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q 31 (scalar "Baz")
|
||||
<*> enqueue q (31 :/ scalar "Baz" :/ Nil)
|
||||
-- This is a realistic use. The pump inputs are pre-bound to some
|
||||
-- nodes that produce values when pumped (e.g. read from a
|
||||
-- file).
|
||||
run_ (pump, pump)
|
||||
|
||||
(x, y) <- run (deq, deq)
|
||||
liftIO $ (Scalar 31, Scalar "Baz") @=? x
|
||||
liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
||||
let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil
|
||||
liftIO $ expected @=? x
|
||||
liftIO $ expected @=? y
|
||||
|
||||
testAsync :: Test
|
||||
testAsync = testCase "testAsync" $ runSession $ do
|
||||
(deq, pump) <- build $ do
|
||||
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
|
||||
(deq, pump) <- do
|
||||
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
|
||||
(,) <$> dequeue q
|
||||
<*> enqueue q 10 (scalar "Async")
|
||||
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
|
||||
-- Pumps the queue until canceled by runSession exiting.
|
||||
asyncProdNodes pump
|
||||
-- Picks up a couple values and verifies they are as expected.
|
||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
|
||||
let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil
|
||||
run deq >>= liftIO . (expected @=?)
|
||||
run deq >>= liftIO . (expected @=?)
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testBasic
|
||||
|
|
|
@ -37,6 +37,7 @@ module TensorFlow.Build
|
|||
, renderedNodeDefs
|
||||
, BuildT
|
||||
, Build
|
||||
, MonadBuild(..)
|
||||
, addInitializer
|
||||
, hoistBuildT
|
||||
, evalBuildT
|
||||
|
@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState
|
|||
evalBuildT :: Monad m => BuildT m a -> m a
|
||||
evalBuildT (BuildT f) = evalStateT f initGraphState
|
||||
|
||||
-- | Lift a 'Build' action into a monad, including any explicit op renderings.
|
||||
class Monad m => MonadBuild m where
|
||||
build :: Build a -> m a
|
||||
|
||||
instance Monad m => MonadBuild (BuildT m) where
|
||||
build = hoistBuildT $ return . runIdentity
|
||||
|
||||
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
|
||||
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
|
||||
flushNodeBuffer = do
|
||||
flushNodeBuffer :: MonadBuild m => m [NodeDef]
|
||||
flushNodeBuffer = build $ do
|
||||
ns <- use nodeBuffer
|
||||
nodeBuffer .= []
|
||||
return ns
|
||||
|
@ -229,8 +237,8 @@ flushInitializers = do
|
|||
|
||||
-- | Registers the given node to be executed before the next
|
||||
-- 'TensorFlow.Session.run'.
|
||||
addInitializer :: ControlNode -> Build ()
|
||||
addInitializer (ControlNode o) = do
|
||||
addInitializer :: MonadBuild m => ControlNode -> m ()
|
||||
addInitializer (ControlNode o) = build $ do
|
||||
i <- getOrAddOp o
|
||||
initializationNodes %= (i:)
|
||||
|
||||
|
@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer
|
|||
gs = snd $ runIdentity $ runBuildT b
|
||||
|
||||
-- TODO: check against existing nodes for conflicts?
|
||||
addGraphDef :: GraphDef -> Build ()
|
||||
addGraphDef g = nodeBuffer <>= g ^. node
|
||||
addGraphDef :: MonadBuild m => GraphDef -> m ()
|
||||
addGraphDef g = build $ nodeBuffer <>= g ^. node
|
||||
|
||||
-- | Render the given op if it hasn't been rendered already, and return its
|
||||
-- name.
|
||||
|
@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do
|
|||
|
||||
-- | Modify some part of the state, run an action, and restore the state
|
||||
-- after that action is done.
|
||||
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
|
||||
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
|
||||
withStateLens accessor f act = do
|
||||
old <- use accessor
|
||||
accessor %= f
|
||||
old <- build $ use accessor
|
||||
build $ accessor %= f
|
||||
result <- act
|
||||
accessor .= old
|
||||
build $ accessor .= old
|
||||
return result
|
||||
|
||||
-- | Set a device for all nodes rendered in the given 'Build' action
|
||||
-- (unless further overridden by another use of withDevice).
|
||||
withDevice :: Maybe Device -> Build a -> Build a
|
||||
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
|
||||
withDevice d = withStateLens defaultDevice (const d)
|
||||
|
||||
-- | Places all nodes rendered in the given 'Build' action on the same
|
||||
-- device as the given Tensor (see also 'withDevice'). Make sure that
|
||||
-- the action has side effects of rendering the desired tensors. A pure
|
||||
-- return would not have the desired effect.
|
||||
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
|
||||
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
|
||||
colocateWith t x = do
|
||||
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
|
||||
withDevice (Just d) x
|
||||
|
||||
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
|
||||
withNameScope :: Text -> Build a -> Build a
|
||||
withNameScope :: MonadBuild m => Text -> m a -> m a
|
||||
withNameScope s = withStateLens currentScope (Scope s :)
|
||||
|
||||
-- | Add control inputs to all nodes rendered in the given 'Build' action.
|
||||
withNodeDependencies :: Set NodeName -> Build a -> Build a
|
||||
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
|
||||
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
||||
|
||||
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
|
||||
|
@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
|
|||
-- This operation is idempotent; @render >=> render === render@. However,
|
||||
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
|
||||
-- may result in two different 'Tensor's.
|
||||
render :: Tensor v a -> Build (Tensor v a)
|
||||
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
||||
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
|
||||
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
|
||||
|
||||
-- | Render a 'Tensor' and get its node's name.
|
||||
renderNodeName :: Tensor v a -> Build NodeName
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module TensorFlow.BuildOp
|
||||
|
@ -21,6 +23,7 @@ module TensorFlow.BuildOp
|
|||
, buildOp
|
||||
, buildListOp
|
||||
, eqLengthGuard
|
||||
, OpParams
|
||||
)
|
||||
where
|
||||
|
||||
|
@ -33,6 +36,7 @@ import Lens.Family2 ((&), (<>~), (^.))
|
|||
import TensorFlow.Build
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
data ResultState = ResultState !OutputIx [Int64] deriving Show
|
||||
|
||||
|
@ -98,6 +102,22 @@ instance OpResult (Tensor Ref a) where
|
|||
instance OpResult ControlNode where
|
||||
toResult = ControlNode <$> ask
|
||||
|
||||
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
|
||||
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
|
||||
where
|
||||
loop :: TensorTypeList bs -> Result (TensorList v bs)
|
||||
loop Nil = return Nil
|
||||
loop (TensorTypeProxy :/ ls) = do
|
||||
t <- tensorResult v
|
||||
ts <- loop ls
|
||||
return (t :/ ts)
|
||||
|
||||
instance TensorTypes as => OpResult (TensorList Value as) where
|
||||
toResult = tensorListResult ValueKind
|
||||
|
||||
instance TensorTypes as => OpResult (TensorList Ref as) where
|
||||
toResult = tensorListResult RefKind
|
||||
|
||||
instance OpResult a => OpResult [a] where
|
||||
toResult = do
|
||||
ResultState i ns <- get
|
||||
|
@ -159,6 +179,12 @@ instance BuildOp (Tensor Value a) where
|
|||
instance BuildOp (Tensor Ref a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance TensorTypes as => BuildOp (TensorList Value as) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance TensorTypes as => BuildOp (TensorList Ref as) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp [Tensor Value a] where
|
||||
buildOp' = pureResult
|
||||
|
||||
|
@ -199,6 +225,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where
|
|||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
|
||||
|
||||
instance BuildOp f => BuildOp (TensorList v as -> f) where
|
||||
buildOp' rf o accum ts
|
||||
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
|
||||
|
||||
-- | Returns true if all the integers in each tuple are identical.
|
||||
-- Throws an error with a descriptive message if not.
|
||||
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
|
||||
|
@ -209,3 +239,7 @@ eqLengthGuard = all eachOk
|
|||
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
|
||||
error ("number_attr " ++ numberAttrName ++
|
||||
" contains tensors with different length " ++ show pairs)
|
||||
|
||||
-- | Parameters to build an op (for example, the node name or optional attributes).
|
||||
-- TODO: be more type safe.
|
||||
type OpParams = OpDef -> OpDef
|
||||
|
|
|
@ -22,27 +22,21 @@ module TensorFlow.ControlFlow
|
|||
withControlDependencies
|
||||
, group
|
||||
-- * Operations
|
||||
, identity
|
||||
, noOp
|
||||
, named
|
||||
) where
|
||||
|
||||
import qualified Data.Set as Set
|
||||
import Data.Text (Text)
|
||||
import Lens.Family2 ((&), (^.), (.~))
|
||||
import Lens.Family2 ((&), (.~))
|
||||
|
||||
import TensorFlow.BuildOp
|
||||
import TensorFlow.Build
|
||||
import TensorFlow.Nodes
|
||||
import TensorFlow.Output
|
||||
import TensorFlow.Tensor
|
||||
import TensorFlow.Types
|
||||
|
||||
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
|
||||
-- on the nodes in the first argument.
|
||||
withControlDependencies :: Nodes t => t -> Build a -> Build a
|
||||
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
|
||||
withControlDependencies deps act = do
|
||||
nodes <- getNodes deps
|
||||
nodes <- build $ getNodes deps
|
||||
withNodeDependencies nodes act
|
||||
|
||||
-- TODO(judahjacobson): Reimplement withDependencies.
|
||||
|
@ -51,37 +45,12 @@ withControlDependencies deps act = do
|
|||
--
|
||||
-- When this op finishes, all ops in the input @n@ have finished. This op has
|
||||
-- no output.
|
||||
group :: Nodes t => t -> Build ControlNode
|
||||
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
|
||||
group deps = do
|
||||
nodes <- Set.toList <$> getNodes deps
|
||||
nodes <- build $ Set.toList <$> getNodes deps
|
||||
-- TODO: slicker way
|
||||
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
|
||||
|
||||
|
||||
-- | Returns a 'Tensor' with the same shape and contents as the input.
|
||||
identity :: TensorType a => Tensor v a -> Tensor v a
|
||||
identity = namedIdentity implicitName
|
||||
|
||||
-- | Returns a 'Tensor' with a given name and the same shape and contents as
|
||||
-- the input.
|
||||
--
|
||||
-- TODO(judahjacobson): This breaks when used with uninitialize @Tensor Ref@s,
|
||||
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
|
||||
-- whether we can change that op.
|
||||
named :: TensorType a => Text -> Tensor v a -> Tensor v a
|
||||
named = namedIdentity . explicitName
|
||||
|
||||
-- | An internal version of "identity" that allows setting the name
|
||||
-- of the output Tensor.
|
||||
namedIdentity :: forall a v . TensorType a
|
||||
=> PendingNodeName -> Tensor v a -> Tensor v a
|
||||
namedIdentity n t = case t ^. tensorKind of
|
||||
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
|
||||
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
|
||||
where
|
||||
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
|
||||
|
||||
|
||||
-- | Does nothing. Only useful as a placeholder for control edges.
|
||||
noOp :: ControlNode
|
||||
noOp = buildOp $ opDef "NoOp"
|
||||
|
|
|
@ -31,8 +31,7 @@ module TensorFlow.Core
|
|||
, runSession
|
||||
, runSessionWithOptions
|
||||
-- ** Building graphs
|
||||
, build
|
||||
, buildAnd
|
||||
, MonadBuild(..)
|
||||
-- ** Running graphs
|
||||
, Fetchable
|
||||
, Nodes
|
||||
|
@ -51,14 +50,14 @@ module TensorFlow.Core
|
|||
, render
|
||||
, asGraphDef
|
||||
, addGraphDef
|
||||
|
||||
, opName
|
||||
, opAttr
|
||||
-- * Tensor
|
||||
, ControlNode
|
||||
, Tensor
|
||||
, Value
|
||||
, Ref
|
||||
, TensorKind(..)
|
||||
, tensorAttr
|
||||
, value
|
||||
, tensorFromName
|
||||
-- ** Element types
|
||||
|
@ -75,12 +74,10 @@ module TensorFlow.Core
|
|||
, Device(..)
|
||||
, withDevice
|
||||
, withNameScope
|
||||
, named
|
||||
-- ** Dependencies
|
||||
, withControlDependencies
|
||||
, group
|
||||
-- ** Misc
|
||||
, identity
|
||||
, noOp
|
||||
) where
|
||||
|
||||
|
|
|
@ -12,15 +12,18 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
module TensorFlow.Nodes where
|
||||
|
||||
import Control.Applicative (liftA2, liftA3)
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Set (Set)
|
||||
|
@ -96,6 +99,19 @@ instance Nodes ControlNode where
|
|||
instance a ~ () => Fetchable ControlNode a where
|
||||
getFetch _ = return $ pure ()
|
||||
|
||||
instance Nodes (ListOf f '[]) where
|
||||
getNodes _ = return Set.empty
|
||||
|
||||
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
|
||||
getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
|
||||
|
||||
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
|
||||
getFetch _ = return $ pure Nil
|
||||
|
||||
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
|
||||
=> Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
|
||||
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
|
||||
|
||||
instance Nodes (Tensor v a) where
|
||||
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)
|
||||
|
||||
|
|
|
@ -124,6 +124,9 @@ data OpDef = OpDef
|
|||
data PendingNodeName = ExplicitName !Text | ImplicitName
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString PendingNodeName where
|
||||
fromString = ExplicitName . fromString
|
||||
|
||||
-- | The name of a node in the graph. This corresponds to the proto field
|
||||
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
|
||||
-- (if the node was implicitly named).
|
||||
|
|
|
@ -26,8 +26,7 @@ module TensorFlow.Session (
|
|||
sessionTracer,
|
||||
runSession,
|
||||
runSessionWithOptions,
|
||||
build,
|
||||
buildAnd,
|
||||
MonadBuild(..),
|
||||
extend,
|
||||
addGraphDef,
|
||||
run,
|
||||
|
@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift)
|
|||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (Default, def)
|
||||
import Data.Functor.Identity (runIdentity)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.ProtoLens (showMessage)
|
||||
import Data.Set (Set)
|
||||
|
@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) =
|
|||
FFI.setSessionTarget (options ^. sessionTarget) opt
|
||||
FFI.setSessionConfig (options ^. sessionConfig) opt
|
||||
|
||||
-- | Lift a 'Build' action into a 'Session', including any explicit op
|
||||
-- renderings.
|
||||
build :: Build a -> Session a
|
||||
build = Session . lift . hoistBuildT (return . runIdentity)
|
||||
instance MonadBuild Session where
|
||||
build = Session . lift . build
|
||||
|
||||
-- | Add all pending rendered nodes to the TensorFlow graph and runs
|
||||
-- any pending initializers.
|
||||
|
@ -147,13 +143,6 @@ extend = do
|
|||
unless (null initializers) $
|
||||
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
|
||||
|
||||
-- | Helper combinator for doing something with the result of a 'Build' action.
|
||||
-- Example usage:
|
||||
--
|
||||
-- > buildAnd run :: Fetchable t a => Build t -> Session a
|
||||
buildAnd :: (a -> Session b) -> Build a -> Session b
|
||||
buildAnd f m = build m >>= f
|
||||
|
||||
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
|
||||
-- rendered, and fetch the corresponding values for 'a'.
|
||||
run :: Fetchable t a => t -> Session a
|
||||
|
|
|
@ -12,20 +12,29 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE Rank2Types #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
||||
module TensorFlow.Tensor where
|
||||
|
||||
import Data.String (IsString(..))
|
||||
import qualified Data.Text as Text
|
||||
import Lens.Family2 (Lens', Traversal')
|
||||
import Lens.Family2 (Lens', (^.))
|
||||
import Lens.Family2.Unchecked (lens)
|
||||
|
||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
||||
import TensorFlow.Types (TensorData(..), Attribute)
|
||||
import TensorFlow.Output (Output)
|
||||
import TensorFlow.Types
|
||||
( TensorData(..)
|
||||
, ListOf(..)
|
||||
)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
||||
-- | A named output of a TensorFlow operation.
|
||||
|
@ -52,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
|
|||
tensorOutput :: Lens' (Tensor v a) Output
|
||||
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||
|
||||
-- TODO: Come up with a better API for handling attributes.
|
||||
-- | Lens for the attributes of a tensor.
|
||||
--
|
||||
-- Only valid if the tensor has not yet been rendered. If the tensor has been
|
||||
-- rendered, the traversal will be over nothing (nothing can be read or
|
||||
-- written).
|
||||
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
||||
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
||||
|
||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||
-- Ref into Value. This behaves like a no-op.
|
||||
value :: Tensor v a -> Tensor Value a
|
||||
|
@ -83,3 +83,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td
|
|||
-- TODO(judahjacobson): add more safety checks here.
|
||||
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
|
||||
tensorFromName v = Tensor v . fromString . Text.unpack
|
||||
|
||||
type TensorList v = ListOf (Tensor v)
|
||||
|
||||
tensorListOutputs :: TensorList v as -> [Output]
|
||||
tensorListOutputs Nil = []
|
||||
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
|
||||
|
|
|
@ -13,9 +13,11 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
@ -36,23 +38,35 @@ module TensorFlow.Types
|
|||
, Shape(..)
|
||||
, protoShape
|
||||
, Attribute(..)
|
||||
, DataType(..)
|
||||
-- * Lists
|
||||
, ListOf(..)
|
||||
, List
|
||||
, (/:/)
|
||||
, TensorTypeProxy(..)
|
||||
, TensorTypes(..)
|
||||
, TensorTypeList
|
||||
, fromTensorTypeList
|
||||
, fromTensorTypes
|
||||
-- * Type constraints
|
||||
, OneOf
|
||||
, type (/=)
|
||||
, OneOfs
|
||||
-- ** Implementation of constraints
|
||||
, TypeError
|
||||
, ExcludedCase
|
||||
, TensorTypes
|
||||
, NoneOf
|
||||
, type (\\)
|
||||
, Delete
|
||||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Data.String (IsString)
|
||||
import Data.Word (Word8, Word16, Word64)
|
||||
import Foreign.Storable (Storable)
|
||||
|
@ -376,6 +390,44 @@ instance Attribute [DataType] where
|
|||
instance Attribute [Int64] where
|
||||
attrLens = list . i
|
||||
|
||||
-- | A heterogeneous list type.
|
||||
data ListOf f as where
|
||||
Nil :: ListOf f '[]
|
||||
(:/) :: f a -> ListOf f as -> ListOf f (a ': as)
|
||||
|
||||
infixr 5 :/
|
||||
|
||||
type family All f as :: Constraint where
|
||||
All f '[] = ()
|
||||
All f (a ': as) = (f a, All f as)
|
||||
|
||||
type family Map f as where
|
||||
Map f '[] = '[]
|
||||
Map f (a ': as) = f a ': Map f as
|
||||
|
||||
instance All Eq (Map f as) => Eq (ListOf f as) where
|
||||
Nil == Nil = True
|
||||
(x :/ xs) == (y :/ ys) = x == y && xs == ys
|
||||
-- Newer versions of GHC use the GADT to tell that the previous cases are
|
||||
-- exhaustive.
|
||||
#if _GLASGOW_HASKELL__ < 800
|
||||
_ == _ = False
|
||||
#endif
|
||||
|
||||
instance All Show (Map f as) => Show (ListOf f as) where
|
||||
showsPrec _ Nil = showString "Nil"
|
||||
showsPrec d (x :/ xs) = showParen (d > 10)
|
||||
$ showsPrec 6 x . showString " :/ "
|
||||
. showsPrec 6 xs
|
||||
|
||||
type List = ListOf Identity
|
||||
|
||||
-- | Equivalent of ':/' for lists.
|
||||
(/:/) :: a -> List as -> List (a ': as)
|
||||
(/:/) = (:/) . Identity
|
||||
|
||||
infixr 5 /:/
|
||||
|
||||
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
||||
--
|
||||
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
||||
|
@ -393,13 +445,38 @@ instance Attribute [Int64] where
|
|||
--
|
||||
-- using an enumeration of all the possible 'TensorType's.
|
||||
type OneOf ts a
|
||||
-- Assert `TensorTypes ts` to make error messages a little better.
|
||||
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
|
||||
|
||||
-- | A check that the input is a list of 'TensorType's.
|
||||
-- Helps improve error messages when using 'OneOf'.
|
||||
type OneOfs ts as = (TensorTypes as, TensorTypes ts,
|
||||
NoneOfs (AllTensorTypes \\ ts) as)
|
||||
|
||||
type family NoneOfs ts as :: Constraint where
|
||||
NoneOfs ts '[] = ()
|
||||
NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)
|
||||
|
||||
data TensorTypeProxy a where
|
||||
TensorTypeProxy :: TensorType a => TensorTypeProxy a
|
||||
|
||||
type TensorTypeList = ListOf TensorTypeProxy
|
||||
|
||||
fromTensorTypeList :: TensorTypeList ts -> [DataType]
|
||||
fromTensorTypeList Nil = []
|
||||
fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts)
|
||||
= tensorType (undefined :: t) : fromTensorTypeList ts
|
||||
|
||||
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
|
||||
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
|
||||
|
||||
class TensorTypes (ts :: [*]) where
|
||||
instance TensorTypes '[]
|
||||
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts)
|
||||
tensorTypes :: TensorTypeList ts
|
||||
|
||||
instance TensorTypes '[] where
|
||||
tensorTypes = Nil
|
||||
|
||||
-- | A constraint that the input is a list of 'TensorTypes'.
|
||||
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
|
||||
tensorTypes = TensorTypeProxy :/ tensorTypes
|
||||
|
||||
-- | A constraint checking that two types are different.
|
||||
type family a /= b :: Constraint where
|
||||
|
|
Loading…
Reference in New Issue
Block a user