
Merge branch 'master' of https://github.com/tensorflow/haskell into webarchive

avctrh 2017-04-02 00:10:40 -04:00
commit ce4902e8ac
30 changed files with 624 additions and 435 deletions


@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
let x = TF.vector xData
y = TF.vector yData
-- Create scalar variables for slope and intercept.
w <- TF.build (TF.initializedVariable 0)
b <- TF.build (TF.initializedVariable 0)
w <- TF.initializedVariable 0
b <- TF.initializedVariable 0
-- Define the loss function.
let yHat = (x `TF.mul` w) `TF.add` b
loss = TF.square (yHat `TF.sub` y)
-- Optimize with gradient descent.
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
trainStep <- gradientDescent 0.001 loss [w, b]
replicateM_ 1000 (TF.run trainStep)
-- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
gradientDescent :: Float
-> TF.Tensor TF.Value Float
-> [TF.Tensor TF.Ref Float]
-> TF.Build TF.ControlNode
-> TF.Session TF.ControlNode
gradientDescent alpha loss params = do
let applyGrad param grad =
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
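
For context, with this change the whole example runs inside TF.Session, so a driver needs no explicit TF.build calls. A minimal sketch of invoking it, assuming fit :: [Float] -> [Float] -> IO (Float, Float) as in this example file; the data points are illustrative, not from the commit:

    -- Fit y = 2x + 1 on four points; expect (w, b) near (2.0, 1.0).
    main :: IO ()
    main = do
        (w, b) <- fit [1, 2, 3, 4] [3, 5, 7, 9]
        print (w, b)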


@ -64,43 +64,9 @@ generatingOpsWrappers = hooks
(prettyLazyText 80 $ docOpList flags x)
blackList =
-- A few data flow ops take a list of heterogeneous
-- parameters which we don't support in general form.
[ "HashTable"
, "MutableDenseHashTable"
, "MutableHashTable"
, "MutableHashTableOfTensors"
, "QueueDequeue"
, "QueueDequeueMany"
, "QueueDequeueUpTo"
, "Stack"
, "TensorArray"
, "TensorArrayV2"
, "QueueEnqueueManyV2"
, "QueueDequeueV2"
, "QueueDequeueUpToV2"
, "QueueEnqueueV2"
, "QueueDequeueManyV2"
, "Stage"
, "Unstage"
-- These should be possible to support by adding a bunch of
-- overloads with a variable number of tuple arguments.
, "Assert"
, "BarrierTakeMany"
, "Print"
, "QueueEnqueue"
, "QueueEnqueueMany"
-- Need list of types support.
, "DecodeCSV"
, "ParseExample"
, "ParseSingleSequenceExample"
, "RestoreV2"
, "Save"
, "SaveV2"
, "SaveSlices"
, "SymbolicGradient"
, "_ArrayToList"
, "_ListToArray"
[ -- Requires the "func" type:
"SymbolicGradient"
-- Easy: support larger result tuples.
, "ParseSingleSequenceExample"
, "Skipgram"
]


@ -49,7 +49,7 @@ import TensorFlow.Tensor
)
import TensorFlow.Ops
import TensorFlow.Session
(runSession, run, run_, runWithFeeds, build, buildAnd)
(runSession, run, run_, runWithFeeds, build)
import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
@ -108,7 +108,7 @@ testGraphDefExec :: Test
testGraphDefExec = testCase "testGraphDefExec" $ do
let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
runSession $ do
build $ addGraphDef graphDef
addGraphDef graphDef
x <- run $ tensorFromName ValueKind "Mul_2"
liftIO $ (50 :: Float) @=? unScalar x
@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
wtsCkptPath <- liftIO wtsCkpt
biasCkptPath <- liftIO biasCkpt
-- Run those restoring nodes on the graph in the current session.
buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
run_ =<< (sequence :: Monad m => [m a] -> m [a])
[ restore wtsCkptPath wts
, restoreFromName biasCkptPath "bias" bias
]


@ -23,7 +23,7 @@ module TensorFlow.NN
import Prelude hiding ( log
, exp
)
import TensorFlow.Build ( Build
import TensorFlow.Build ( MonadBuild
, render
, withNameScope
)
@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
--
-- `logits` and `targets` must have the same type and shape.
sigmoidCrossEntropyWithLogits
:: (OneOf '[Float, Double] a, TensorType a, Num a)
:: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
=> Tensor Value a -- ^ __logits__
-> Tensor Value a -- ^ __targets__
-> Build (Tensor Value a)
-> m (Tensor Value a)
sigmoidCrossEntropyWithLogits logits targets = do
logits' <- render logits
targets' <- render targets
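
The refactoring above only generalizes the monad; the computation is unchanged. For reference, per element with logits x and targets z, this op evaluates the numerically stable form of sigmoid cross entropy (the standard formulation, stated here for context rather than quoted from this commit):

    \max(x, 0) - x z + \log\left(1 + e^{-|x|}\right)

which equals -z log(sigmoid x) - (1 - z) log(1 - sigmoid x) but avoids overflow when |x| is large.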


@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.NN as TF
@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do
assertAllClose (head r) (V.fromList [0.5, -0.5])
run :: TF.Fetchable t a => TF.Build t -> IO a
run = TF.runSession . TF.buildAnd TF.run
run :: TF.Fetchable t a => TF.Session t -> IO a
run = TF.runSession . (>>= TF.run)
main :: IO ()
main = googleTest [ testGradientAtZero


@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
@ -147,6 +148,7 @@ imports = stack [
"import Data.ByteString (ByteString)"
, "import Data.Complex (Complex)"
, "import Data.Int (Int8, Int16, Int32, Int64)"
, "import Data.Proxy (Proxy(Proxy))"
, "import Data.Word (Word8, Word16)"
, "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build"
@ -171,18 +173,28 @@ renderQuotedTFName = dquotes . renderTFName
renderOp :: ParsedOp -> Doc
renderOp pOp = stack $
[ haddocks
, n <+> "::" <+> hang 0 (typeSig pOp)
, n <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
-- Prevent unreasonably long compilation times on ghc-7.10, due
-- to stack calling "-dump-hi" which (unnecessarily) includes the
-- inlining information, and is large for ops with many arguments.
#if __GLASGOW_HASKELL__ < 800
, "{-# NOINLINE " <> n <> "#-}"
#endif
, n <+> "::" <+> hang 0 (typeSig empty pOp)
, n <+> "=" <+> n <> "' id"
, n' <+> "::" <+> hang 0 (typeSig "OpParams ->" pOp)
, n' <+> hang 0 args <+> "|" <+> funcGuard listSizeAttrs
<+> "=" </> -- args are indented
-- the body needs to be indented wrt the name
indent indentation (functionBody pOp)
] ++ whereClause listSizeAttrs
where
n = renderHaskellName $ parsedOpName pOp
n' = n <> "'"
listSizeAttrs = inferredListSizeAttrs pOp
args = sep $ map renderHaskellName
$ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp)
args = sep $ "op'options"
: (map renderHaskellName
$ map attrName (explicitInputAttrs pOp)
++ map parsedArgName (parsedInputs pOp))
haddocks = "-- |" <+> multilineComment (parsedOpSummary pOp) (parsedOpDescription pOp)
-- | A check that all lists of the given size have the given length.
@ -210,15 +222,21 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = []
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
where
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
defineLengthAttr a = renderHaskellAttrName a <+> "="
<+> "fromIntegral (length"
<+> renderHaskellName (NE.head $ attrInfo a)
<> ") :: Int64"
renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
where
maybeLift
| parsedOpIsMonadic pOp = "build $"
| otherwise = ""
buildFunction
| null outputListsSizes = "buildOp"
| otherwise = "buildListOp" <+>
@ -229,9 +247,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
<- parsedOutputs pOp]
buildOpParts =
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders tensor arguments.
[ "& opAttr" <+> renderQuotedTFName n <+>
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
-- Renders type parameter arguments.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
| a <- inferredTypeAttrs pOp, let n = attrName a
] ++
-- Renders mandatory attributes as function parameters.
@ -241,9 +258,17 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
-- Renders sizes of tensor list types having number_attr.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> renderHaskellName n
| a <- inferredListSizeAttrs pOp, let n = attrName a
]
] ++
["& op'options"]
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a
| typeParamIsList $ attrInfo a
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
<> ")"
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
<> ")"
-- | Write a comment with the inputs/outputs/attributes in proto format, for
-- debugging.
@ -258,23 +283,28 @@ extras d = enclose "{-\n" "\n-}" $
-- | The type signature for an op.
-- Of the form:
-- forall t1 t2 v1 v2 . (TensorType t1, TensorType t2)
-- => Float -> Tensor t1 v1 -> Tensor t2 v2
-- => {pre} Float -> Tensor t1 v1 -> Tensor t2 v2
-- where "Float" is an explicit input attribute, "Tensor t1 v1" is an input, and
-- "Tensor t2 v2" is an output.
typeSig :: ParsedOp -> Doc
typeSig pOp = constraints
<+/> signatureFold (map attrInput (explicitInputAttrs pOp)
typeSig :: Doc -> ParsedOp -> Doc
typeSig pre pOp = constraints
<+/> pre </> signatureFold (map attrInput (explicitInputAttrs pOp)
++ map tensorArgAndComment (parsedInputs pOp)
++ [outputs])
where
constraints
| null (inferredTypeAttrs pOp) = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
| null classConstraints = empty
| otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint
$ inferredTypeAttrs pOp
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
++ if parsedOpIsMonadic pOp then ["m'"] else []
-- Use m' as the type parameter to avoid clashing with an attribute name.
monadConstraint
| parsedOpIsMonadic pOp = ["MonadBuild m'"]
| otherwise = []
classConstraints = monadConstraint ++ map tensorArgConstraint
(inferredTypeAttrs pOp)
signatureFold = folddoc (\x y -> x </> "->" <+> y)
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
renderAttrType (AttrSingle a) = renderAttrBaseType a
@ -295,7 +325,7 @@ typeSig pOp = constraints
[a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
wrapOutput o
| parsedOpIsMonadic pOp = "Build" <+> parens o
| parsedOpIsMonadic pOp = "m'" <+> parens o
| otherwise = o
-- | Render an op input or output.
@ -305,17 +335,18 @@ tensorArg p = case parsedArgCase p of
ResourceArg -> "ResourceHandle"
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
MixedListArg {argTypeAttr = t, argCaseKind = k}
-> "TensorList" <+> kind k <+> renderHaskellName t
where
kind k = case k of
ArgTensorRef -> "Ref"
ArgTensorValue -> "Value"
ArgTensorEither v' -> strictText v'
tensorType t k = let
v = case k of
ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v'
a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n
in v <+> a
in "Tensor" <+> kind k <+> a
attrComment :: Attr a -> Doc
attrComment a = argComment' (attrName a) (attrDescription a)
@ -347,18 +378,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os
]
-- | Constraints for a given type parameter.
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
tensorArgConstraint :: Attr [DataType] -> [Doc]
tensorArgConstraint a
= ("TensorType" <+> n
: if null typeList
then []
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
-- E.g.: "TensorType t" or "OneOf [Int64, Float] t"
-- or "TensorTypes ts" or "OneOfs [..] ts".
tensorArgConstraint :: Attr TypeParam -> Doc
tensorArgConstraint a = case attrInfo a of
TypeParam False Nothing -> "TensorType" <+> n
TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n
TypeParam True Nothing -> "TensorTypes" <+> n
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
where
n = renderHaskellName $ attrName a
typeList = map strictText $
Set.toList $ Set.fromList $
map dtTypeToHaskell $ attrInfo a
n = renderHaskellAttrName a
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
typeList = ("'" <>) . brackets . commasep . map strictText .
Set.toList . Set.fromList .
map dtTypeToHaskell . toList
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.
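
To make the generator changes concrete: every wrapper now comes in a plain flavor and a primed flavor that accepts OpParams. A hypothetical, hand-simplified rendering of what renderOp emits for a pure unary op (the op name and constraint are illustrative, and guards/haddocks are omitted):

    neg :: forall v t . (TensorType t) => Tensor v t -> Tensor Value t
    neg = neg' id

    neg' :: forall v t . (TensorType t)
         => OpParams -> Tensor v t -> Tensor Value t
    neg' op'options x = buildOp (opDef "Neg"
                                 & opAttr "T" .~ tensorType (undefined :: t)
                                 & op'options)
                                x

Per maybeLift above, stateful ops additionally pick up a MonadBuild m' constraint and wrap the body in build $.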


@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp
, Attr(..)
, AttrType(..)
, AttrBaseType(..)
, TypeParam(..)
, ParsedArg(..)
, ParsedArgCase(..)
, ArgType(..)
@ -62,10 +63,8 @@ data ParsedOp = ParsedOp
, explicitInputAttrs :: [Attr AttrType]
-- ^ Attributes that must be set explicitly when creating the op.
-- Associated with the type of the attribute.
, inferredTypeAttrs :: [Attr [DataType]]
, inferredTypeAttrs :: [Attr TypeParam]
-- ^ Attributes that are type parameters.
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
-- If this list is empty, then any type is acceptable.
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
-- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors.
@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor
deriving Eq
data TypeParam = TypeParam
{ typeParamIsList :: Bool
, typeParamRestrictions :: Maybe (NonEmpty DataType)
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
-- If 'Nothing', then any type is acceptable.
}
-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
{ parsedArgName :: Name
@ -120,7 +126,6 @@ data ParsedArgCase
}
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
-- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind
@ -223,11 +228,6 @@ parseOp o = ParsedOp
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
-- Type attributes that can be inferred from at least one input or output.
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
$ parsedInputs ++ parsedOutputs
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
-- Integer attributes that can be inferred from the size of at least one
-- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
@ -235,10 +235,14 @@ parseOp o = ParsedOp
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
-- Attributes that can't be inferred and don't have defaults, so must be passed
-- as separate arguments to the op.
inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
argTypeParams = Set.fromList $ map tfName $
mapMaybe (getArgTypeParam . parsedArgCase) $
parsedInputs ++ parsedOutputs
-- Attributes that can't be inferred and don't have defaults, so must be
-- passed as separate arguments to the op.
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
$ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs.
@ -252,29 +256,30 @@ outputTensorKind a
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr implicitAttrs a
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
, t <- parseAttrType o (a ^. type')
, t `elem` map AttrSingle
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
++ [AttrList AttrType] = Just t
| otherwise = Nothing
-- | The type attribute used by this input or output (if any).
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
parsedArgTypeAttr p = case parsedArgCase p of
ResourceArg -> Nothing
SimpleArg {argType = t} -> fromArgType t
ListArg {argType = t} -> fromArgType t
MixedListArg {argTypeAttr = n} -> Just $ tfName n
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
getInferredTypeAttr argTypeParams a
| TFName (a ^. name) `notElem` argTypeParams = Nothing
| a ^. type' == "type" = Just $ TypeParam False allowed
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
| otherwise = Nothing
where
fromArgType (ArgTypeAttr n) = Just $ tfName n
fromArgType _ = Nothing
allowed = nonEmpty (a ^. allowedValues . list . type')
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
| otherwise = Nothing
getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
getArgTypeParam _ = Nothing
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
@ -285,7 +290,7 @@ getInferredListSizeAttr inputs a
} <- inputs
, TFName (a ^. name) == tfName n]
| otherwise = Nothing
-- | Like mapMaybe, but associates the attribute name/description with the given info.
mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a]
mapMaybeAttrs f = mapMaybe $ \a -> do
@ -295,7 +300,7 @@ mapMaybeAttrs f = mapMaybe $ \a -> do
, attrDescription = a ^. description
, attrInfo = x
}
parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg
parseArg a tKind = ParsedArg
{ parsedArgName = makeName (a ^. name)
@ -317,15 +322,15 @@ parseArgCase a tKind
maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t
parseAttrType :: Text -> AttrType
parseAttrType = \case
parseAttrType :: OpDef -> Text -> AttrType
parseAttrType o = \case
"string" -> AttrSingle AttrBytes
"int" -> AttrSingle AttrInt64
"float" -> AttrSingle AttrFloat
"bool" -> AttrSingle AttrBool
"type" -> AttrSingle AttrType
"shape" -> AttrSingle AttrShape
"tensor" -> AttrSingle AttrTensor
"list(string)" -> AttrList AttrBytes
"list(int)" -> AttrList AttrInt64
"list(float)" -> AttrList AttrFloat
@ -334,3 +339,4 @@ parseAttrType = \case
"list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor
t -> error $ "parseAttrType: unrecognized type " ++ show t
++ " for op " ++ show (o ^. name)


@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where
import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Build (MonadBuild, colocateWith, render)
import TensorFlow.Ops (shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
--
-- The results of the lookup are concatenated into a dense
-- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
embeddingLookup :: forall a b v .
( TensorType a
embeddingLookup :: forall a b v m .
( MonadBuild m
, TensorType a
, OneOf '[Int64, Int32] b
, Num b
)
@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
-- containing the ids to be looked up in `params`.
-- The ids are required to have fewer than 2^31
-- entries.
-> Build (Tensor Value a)
-> m (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do
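
(The hunk above is cut off by the diff context.) A minimal usage sketch under the new signature, mirroring the shapes used in the EmbeddingOpsTest changes later in this commit and assuming the usual qualified TF and V imports:

    lookupExample :: TF.Session (V.Vector Float)
    lookupExample = do
        let embedding = TF.constant (TF.Shape [2, 1]) [1, 20 :: Float]
            ids = TF.constant (TF.Shape [1, 2]) [1, 1 :: Int32]
        vs <- embeddingLookup [embedding] ids
        TF.run vs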


@ -56,7 +56,9 @@ import qualified Data.Text as Text
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Build
( Build
( MonadBuild
, Build
, build
, render
, renderNodeName
, renderedNodeDefs
@ -70,6 +72,7 @@ import TensorFlow.Ops
, expandDims
, fill
, matMul
, matMul'
, reducedShape
, reluGrad
, reshape
@ -93,7 +96,6 @@ import TensorFlow.Tensor
, TensorKind (ValueKind)
, Value
, tensorOutput
, tensorAttr
)
import TensorFlow.Types (Attribute, OneOf, TensorType, attrLens)
import Proto.Tensorflow.Core.Framework.NodeDef
@ -111,16 +113,17 @@ type GradientCompatible a =
-- | Gradient of @y@ w.r.t. each element of @xs@.
gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
gradients :: forall a v1 v2 m . (MonadBuild m
, Num (Tensor v1 a)
-- TODO(gnezdo): remove indirect constraint.
-- It's a wart inherited from Num instance.
, v1 ~ Value
, GradientCompatible a
)
=> Tensor v1 a -- ^ The output of the graph.
-> [Tensor v2 a] -- ^ Tensors for which gradients are computed.
-> Build [Tensor Value a]
gradients y xs = do
-> m [Tensor Value a]
gradients y xs = build $ do
-- The gradients are computed using "reverse accumulation", similarly to
-- what is described here:
-- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
@ -529,20 +532,20 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
let transposeA = lookupAttr nodeDef "transpose_a"
transposeB = lookupAttr nodeDef "transpose_b"
transAttrs a b =
(tensorAttr "transpose_a" .~ a) . (tensorAttr "transpose_b" .~ b)
(opAttr "transpose_a" .~ a) . (opAttr "transpose_b" .~ b)
in case (transposeA, transposeB) of
(False, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ (x `matMul` dz) & transAttrs True False ]
[ Just $ matMul' (transAttrs False True) dz y
, Just $ matMul' (transAttrs True False) x dz]
(False, True) ->
[ Just $ dz `matMul` y
, Just $ (x `matMul` dz) & transAttrs True False ]
[ Just $ matMul dz y
, Just $ matMul' (transAttrs True False) x dz]
(True, False) ->
[ Just $ (dz `matMul` y) & transAttrs False True
, Just $ x `matMul` dz ]
[ Just $ matMul' (transAttrs False True) dz y
, Just $ matMul x dz]
(True, True) ->
[ Just $ (dz `matMul` y) & transAttrs True True
, Just $ (x `matMul` dz) & transAttrs True True ]
[ Just $ matMul' (transAttrs True True) dz y
, Just $ matMul' (transAttrs True True) x dz]
opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
@ -551,16 +554,18 @@ opGrad "Transpose" _ [_, toT -> p] [dz] =
]
opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ CoreOps.conv2DBackpropInput (shape x) y dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
, Just $ CoreOps.conv2DBackpropFilter x (shape y) dz
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu
& tensorAttr "data_format" .~ dataFormat
[ Just $ CoreOps.conv2DBackpropInput'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
(shape x) y dz
, Just $ CoreOps.conv2DBackpropFilter'
((opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "use_cudnn_on_gpu" .~ useCudnnOnGpu)
. (opAttr "data_format" .~ dataFormat))
x (shape y) dz
]
where
strides = lookupAttr nodeDef "strides" :: [Int64]
@ -569,11 +574,12 @@ opGrad "Conv2D" nodeDef [toT -> x, toT -> y] [dz] =
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
[ Just $ CoreOps.maxPoolGrad x output dz
& tensorAttr "ksize" .~ ksize
& tensorAttr "strides" .~ strides
& tensorAttr "padding" .~ padding
& tensorAttr "data_format" .~ dataFormat
[ Just $ CoreOps.maxPoolGrad'
((opAttr "ksize" .~ ksize)
. (opAttr "strides" .~ strides)
. (opAttr "padding" .~ padding)
. (opAttr "data_format" .~ dataFormat))
x output dz
]
where
output :: Tensor Value a
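
(The where clause above is cut off by the diff context.) As a reference note on the MatMul case: for Z = X Y with incoming gradient dZ, the standard identities are

    \frac{\partial L}{\partial X} = dZ \, Y^{\top} \qquad
    \frac{\partial L}{\partial Y} = X^{\top} \, dZ

which is why the (False, False) branch sets transpose_b for the first product and transpose_a for the second.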


@ -58,56 +58,99 @@
module TensorFlow.Ops
( CoreOps.add
, CoreOps.add'
, CoreOps.abs
, CoreOps.abs'
, CoreOps.addN
, CoreOps.addN'
, CoreOps.argMax
, CoreOps.argMax'
, CoreOps.assign
, CoreOps.assign'
, CoreOps.broadcastGradientArgs
, CoreOps.broadcastGradientArgs'
, CoreOps.cast
, CoreOps.cast'
, CoreOps.concat
, CoreOps.concat'
, constant
, constant'
, CoreOps.equal
, CoreOps.equal'
, expandDims
, expandDims'
, initializedVariable
, initializedVariable'
, zeroInitializedVariable
, zeroInitializedVariable'
, CoreOps.fill
, CoreOps.oneHot
, CoreOps.fill'
, CoreOps.identity
, CoreOps.identity'
, CoreOps.matMul
, CoreOps.matMul'
, matTranspose
, matTranspose'
, CoreOps.mean
, CoreOps.mean'
, CoreOps.mul
, CoreOps.mul'
, CoreOps.neg
, CoreOps.neg'
, CoreOps.oneHot
, CoreOps.oneHot'
, CoreOps.pack
, CoreOps.pack'
, placeholder
, placeholder'
, CoreOps.range
, CoreOps.range'
, reducedShape
, CoreOps.relu
, CoreOps.relu'
, CoreOps.reluGrad
, CoreOps.reluGrad'
, CoreOps.reshape
, CoreOps.reshape'
, restore
, restoreFromName
, save
, scalar
, scalar'
, shape
, shape'
, CoreOps.sign
, CoreOps.sign'
, CoreOps.size
, CoreOps.size'
, CoreOps.softmax
, CoreOps.softmax'
, CoreOps.softmaxCrossEntropyWithLogits
, CoreOps.softmaxCrossEntropyWithLogits'
, CoreOps.sparseToDense
, CoreOps.sparseToDense'
, CoreOps.sub
, CoreOps.sub'
, CoreOps.sum
, CoreOps.sum'
, CoreOps.transpose
, CoreOps.transpose'
, truncatedNormal
, truncatedNormal'
, CoreOps.variable
, CoreOps.variable'
, vector
, vector'
, zeros
, CoreOps.zerosLike
, CoreOps.zerosLike'
, scalarize
) where
import Data.ByteString (ByteString)
import Data.Complex (Complex)
import Data.Int (Int32, Int64)
import Data.Word (Word16)
import Prelude hiding (abs, sum, concat)
import Data.ProtoLens (def)
import Data.Text.Encoding (encodeUtf8)
@ -151,60 +194,73 @@ instance ( TensorType a
signum = CoreOps.sign
negate = CoreOps.neg
matTranspose :: forall a v . TensorType a
=> Tensor v a -> Tensor Value a
matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])
matTranspose :: TensorType a => Tensor v a -> Tensor Value a
matTranspose = matTranspose' id
placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
placeholder shape' =
buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ shape'
matTranspose' :: TensorType a => OpParams -> Tensor v a -> Tensor Value a
matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
placeholder = placeholder' id
placeholder' :: forall m a . (MonadBuild m, TensorType a)
=> OpParams -> Shape -> m (Tensor Value a)
placeholder' params pShape
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
-- and thus would be CSE'd.
= build $ buildOp $ opDef "Placeholder"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "shape" .~ pShape
& params
-- | Creates a variable initialized to the given value.
-- Initialization happens next time session runs.
initializedVariable :: forall a . TensorType a
=> Tensor Value a -> Build (Tensor Ref a)
initializedVariable initializer = do
v <- CoreOps.variable [] -- The shape is not known initially.
(i :: Tensor Ref a) <-
buildOp (opDef "Assign"
& opAttr "T" .~ tensorType (undefined :: a)
& opAttr "use_locking" .~ True
& opAttr "validate_shape" .~ False
)
v initializer
initializedVariable :: (MonadBuild m, TensorType a)
=> Tensor Value a -> m (Tensor Ref a)
initializedVariable = initializedVariable' id
initializedVariable' :: (MonadBuild m, TensorType a)
=> OpParams -> Tensor Value a -> m (Tensor Ref a)
initializedVariable' params initializer = do
v <- CoreOps.variable' params [] -- The shape is not known initially.
i <- CoreOps.assign' (opAttr "validate_shape" .~ False) v
initializer
addInitializer =<< group i
return v
-- | Creates a zero-initialized variable with the given shape.
zeroInitializedVariable
:: (TensorType a, Num a) =>
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = initializedVariable . zeros
:: (MonadBuild m, TensorType a, Num a) =>
TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable = zeroInitializedVariable' id
zeroInitializedVariable'
:: (MonadBuild m, TensorType a, Num a) =>
OpParams -> TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
zeroInitializedVariable' params = initializedVariable' params . zeros
-- TODO: Support heterogeneous list of tensors.
save :: forall a v . TensorType a
save :: forall a m v . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> [Tensor v a] -- ^ Tensors to save.
-> Build ControlNode
-> m ControlNode
save path xs = do
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
let types = replicate (length xs) (tensorType (undefined :: a))
let saveOp = buildOp $ opDef "Save"
& opAttr "T" .~ types
saveOp (scalar path) (CoreOps.pack names) xs
build $ saveOp (scalar path) (CoreOps.pack names) xs
-- | Restore a tensor's value from a checkpoint file.
--
-- This version allows restoring from a checkpoint file that uses a different
-- tensor name than the variable.
restoreFromName :: forall a . TensorType a
restoreFromName :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> ByteString -- ^ Tensor name override.
-> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode
-> m ControlNode
restoreFromName path name x = do
let restoreOp = buildOp $ opDef "Restore"
& opAttr "dt" .~ tensorType (undefined :: a)
@ -212,12 +268,12 @@ restoreFromName path name x = do
(restoreOp (scalar path) (scalar name) :: Tensor Value a)
-- | Restore a tensor's value from a checkpoint file.
restore :: forall a . TensorType a
restore :: forall a m . (MonadBuild m, TensorType a)
=> ByteString -- ^ File path.
-> Tensor Ref a -- ^ Tensor to restore.
-> Build ControlNode
-> m ControlNode
restore path x = do
name <- encodeUtf8 . unNodeName <$> renderNodeName x
name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
restoreFromName path name x
-- | Create a constant tensor.
@ -227,27 +283,27 @@ restore path x = do
-- element 0: index (0, ..., 0)
-- element 1: index (0, ..., 1)
-- ...
constant :: forall a . TensorType a => Shape -> [a] -> Tensor Value a
constant (Shape shape') values
constant :: TensorType a => Shape -> [a] -> Tensor Value a
constant = constant' id
constant' :: forall a . TensorType a => OpParams -> Shape -> [a] -> Tensor Value a
constant' params (Shape cShape) values
| invalidLength = error invalidLengthMsg
| otherwise = buildOp $ opDef "Const"
& opAttr "value" .~ typedNode
& opAttr "dtype" .~ nodeType
| otherwise = CoreOps.const' (params . (opAttr "value" .~ typedNode))
where
invalidLength = product shape' /= fromIntegral (length values)
invalidLength = product cShape /= fromIntegral (length values)
invalidLengthMsg = printf "invalid tensor length: expected %d got %d"
(product shape')
(product cShape)
(length values)
nodeType = tensorType (undefined :: a)
typedNode :: TensorProto
typedNode = def
& dtype .~ nodeType
& dtype .~ tensorType (undefined :: a)
& tensorShape.TensorShape.dim .~
[def & TensorShape.size .~ x | x <- shape']
[def & TensorShape.size .~ x | x <- cShape]
& tensorVal .~ values
-- | Reshape a N-D tensor down to a scalar.
--
-- See `TensorFlow.GenOps.Core.reshape`.
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
scalarize t = CoreOps.reshape t (vector scalarShape)
@ -257,29 +313,46 @@ scalarize t = CoreOps.reshape t (vector scalarShape)
-- | Create a constant vector.
vector :: TensorType a => [a] -> Tensor Value a
vector xs = constant [fromIntegral $ length xs] xs
vector = vector' id
vector' :: TensorType a => OpParams -> [a] -> Tensor Value a
vector' params xs = constant' params [fromIntegral $ length xs] xs
-- | Create a constant scalar.
scalar :: forall a . TensorType a => a -> Tensor Value a
scalar x = constant [] [x]
scalar :: TensorType a => a -> Tensor Value a
scalar = scalar' id
-- Random tensor from the unit normal distribution with bounded values.
truncatedNormal :: forall a v . TensorType a
scalar' :: TensorType a => OpParams -> a -> Tensor Value a
scalar' params x = constant' params [] [x]
-- | Random tensor from the unit normal distribution with bounded values.
--
-- This is a type-restricted version of 'TensorFlow.GenOps.Core.truncatedNormal'.
truncatedNormal :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
=> Tensor v Int64 -- ^ Shape.
-> Build (Tensor Value a)
truncatedNormal = buildOp $ opDef "TruncatedNormal"
& opAttr "dtype" .~ tensorType (undefined :: a)
& opAttr "T" .~ tensorType (undefined :: Int64)
-> m (Tensor Value a)
truncatedNormal = CoreOps.truncatedNormal
truncatedNormal' :: (MonadBuild m, OneOf '[Word16, Double, Float] a)
=> OpParams -> Tensor v Int64 -- ^ Shape.
-> m (Tensor Value a)
truncatedNormal' = CoreOps.truncatedNormal'
zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
zeros (Shape s) = CoreOps.fill (vector $ map fromIntegral s) (scalar 0)
shape :: (TensorType t) => Tensor v1 t -> Tensor Value Int32
shape :: TensorType t => Tensor v1 t -> Tensor Value Int32
shape = CoreOps.shape
expandDims :: (TensorType t) => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
shape' :: TensorType t => OpParams -> Tensor v1 t -> Tensor Value Int32
shape' = CoreOps.shape'
expandDims :: TensorType t => Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims = CoreOps.expandDims
expandDims' :: TensorType t => OpParams -> Tensor v1 t -> Tensor v2 Int32 -> Tensor Value t
expandDims' = CoreOps.expandDims'
-- | Helper function for reduction ops (translation of math_ops.reduced_shape).
reducedShape :: (OneOf '[ Int32, Int64 ] t1, OneOf '[ Int32, Int64 ] t2) =>
Tensor v1 t1 -> Tensor v2 t2 -> Tensor Value Int32
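
Taken together, the pattern above gives every op a plain spelling plus a primed one that threads OpParams (a function over the underlying OpDef) through to the op definition. A minimal sketch, assuming OverloadedStrings plus the (.~) operator and opName lens used elsewhere in this commit:

    -- Create a variable whose graph node is named "weights".
    namedVar :: MonadBuild m => m (Tensor Ref Float)
    namedVar = initializedVariable' (opName .~ "weights") 0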


@ -19,8 +19,7 @@
module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Functor.Identity (runIdentity)
import Lens.Family2 ((^.))
import Lens.Family2 ((^.), (.~))
import Data.List (sort)
import Proto.Tensorflow.Core.Framework.Graph
( node )
@ -35,13 +34,12 @@ import TensorFlow.Build
, asGraphDef
, evalBuildT
, flushNodeBuffer
, hoistBuildT
, render
, withDevice
, colocateWith
, withNameScope
, opName
)
import TensorFlow.ControlFlow (named)
import TensorFlow.Types (unScalar)
import TensorFlow.Ops
( add
@ -49,13 +47,12 @@ import TensorFlow.Ops
, constant
, initializedVariable
, variable
, variable'
)
import TensorFlow.Output (Device(..))
import TensorFlow.Tensor (Tensor, Value, Ref)
import TensorFlow.Session
( build
, buildAnd
, run
( run
, runSession
, run_
)
@ -65,26 +62,16 @@ import Test.HUnit ((@=?))
import Google.Test (googleTest)
import qualified Data.Vector as V
-- | Test named behavior.
testNamed :: Test
testNamed = testCase "testNamed" $ do
let graph = named "foo" <$> variable [] >>= render :: Build (Tensor Ref Float)
-- | Test 'opName' behavior.
testOpName :: Test
testOpName = testCase "testOpName" $ do
let graph = variable' (opName .~ "foo") []
>>= render :: Build (Tensor Ref Float)
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op)
"Variable" @=? (nodeDef ^. op)
"foo" @=? (nodeDef ^. name)
-- | Test named deRef behavior.
testNamedDeRef :: Test
testNamedDeRef = testCase "testNamedDeRef" $ do
let graph = named "foo" <$> do
v :: Tensor Ref Float <- variable []
assign v 5
-- TODO: Implement TensorFlow get_variable and test it.
runSession $ do
out <- buildAnd run graph
liftIO $ 5 @=? (unScalar out :: Float)
-- | Test that "run" will render and extend any pure ops that haven't already
-- been rendered.
testPureRender :: Test
@ -96,7 +83,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
testInitializedVariable :: Test
testInitializedVariable =
testCase "testInitializedVariable" $ runSession $ do
(formula, reset) <- build $ do
(formula, reset) <- do
v <- initializedVariable 42
r <- assign v 24
return (1 `add` v, r)
@ -109,7 +96,7 @@ testInitializedVariable =
testInitializedVariableShape :: Test
testInitializedVariableShape =
testCase "testInitializedVariableShape" $ runSession $ do
vector <- build $ initializedVariable (constant [1] [42 :: Float])
vector <- initializedVariable (constant [1] [42 :: Float])
result <- run vector
liftIO $ [42] @=? (result :: V.Vector Float)
@ -122,33 +109,30 @@ testNameScoped = testCase "testNameScoped" $ do
"foo/Variable_0" @=? (nodeDef ^. name) -- TODO: Check prefix.
"Variable" @=? (nodeDef ^. op)
-- | Test combined named and nameScoped behavior.
-- | Test combined opName and nameScoped behavior.
testNamedAndScoped :: Test
testNamedAndScoped = testCase "testNamedAndScoped" $ do
let graph :: Build (Tensor Ref Float)
graph = withNameScope "foo1" ((named "bar1" <$> variable []) >>= render)
graph = withNameScope "foo1" (variable' (opName .~ "bar1") [])
>>= render
nodeDef :: NodeDef
nodeDef = head $ asGraphDef graph ^. node
"RefIdentity" @=? (nodeDef ^. op)
"Variable" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name)
-- | Lift a Build action into a context for HUnit to run.
liftBuild :: Build a -> BuildT IO a
liftBuild = hoistBuildT (return . runIdentity)
-- | Flush the node buffer and sort the nodes by name (for more stable tests).
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer
flushed field = sort . map field <$> flushNodeBuffer
-- | Test the interaction of rendering, CSE and scoping.
testRenderDedup :: Test
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes
renderNodes
names <- flushed (^. name)
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
-- Render the nodes in a different scope, which should cause them
-- to be distinct from the previous ones.
liftBuild $ withNameScope "foo" renderNodes
withNameScope "foo" renderNodes
scopedNames <- flushed (^. name)
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
where
@ -165,7 +149,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
-- | Test the interaction of rendering, CSE and scoping.
testDeviceColocation :: Test
testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
liftBuild renderNodes
renderNodes
devices <- flushed (\x -> (x ^. name, x ^. device))
liftIO $ [ ("Add_2","dev0")
, ("Const_1","dev0")
@ -182,8 +166,7 @@ main :: IO ()
main = googleTest [ testInitializedVariable
, testInitializedVariableShape
, testDeviceColocation
, testNamed
, testNamedDeRef
, testOpName
, testNameScoped
, testNamedAndScoped
, testPureRender


@ -45,7 +45,7 @@ testDynamicPartitionStitchInverse (StitchExample numParts values partitions) =
restitch = CoreOps.dynamicStitch restitchIndices splitParts
in monadicIO $ run $ do
fromIntegral numParts @=? length splitParts
valuesOut <- TF.runSession $ TF.buildAnd TF.run $ return restitch
valuesOut <- TF.runSession $ TF.run restitch
V.fromList values @=? valuesOut
data StitchExample a = StitchExample Int64 [a] [Int32]


@ -39,11 +39,6 @@ import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Nodes as TF
buildAndRun :: TF.Fetchable t a => TF.Build t -> IO a
buildAndRun = TF.runSession . TF.buildAnd TF.run
-- | Tries to perform a simple embedding lookup, with two partitions.
@ -61,9 +56,9 @@ testEmbeddingLookupHasRightShapeWithPartition =
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup embedding ids
(values, shape) <- buildAndRun $ do
(values, shape) <- TF.runSession $ do
vs <- op
return (vs, TF.shape vs)
TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [1, 2, 3]
@ -87,9 +82,9 @@ testEmbeddingLookupHasRightShape =
let ids = TF.constant (TF.Shape [1, 2]) idValues
let op = embeddingLookup [embedding] ids
(values, shape) <- buildAndRun $ do
(values, shape) <- TF.runSession $ do
vs <- op
return (vs, TF.shape vs)
TF.run (vs, TF.shape vs)
-- This is the shape that is returned in the equiv. Python.
shape @=? V.fromList [1, 2, 3]
@ -106,7 +101,6 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
let shape = TF.Shape [2]
gs <- TF.runSession $ do
grads <- TF.build $ do
let embShape = TF.Shape [2, 1]
let embeddingInit = [1, 20 ::Float]
let idValues = [1, 1 :: Int32]
@ -121,9 +115,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
loss = TF.mean twoNorm (TF.scalar (0 :: Int32))
grad <- fmap head (TF.gradients loss [embedding])
return $ \xs -> TF.runWithFeeds [TF.feed x xs] grad
grads (TF.encodeTensorData shape xVals :: TF.TensorData Float)
TF.runWithFeeds
[TF.feed x $ TF.encodeTensorData shape xVals]
grad
-- Gradients should be zero (or close)
assertAllClose gs (V.fromList ([0, 0 :: Float]))
@ -148,7 +142,7 @@ testEmbeddingLookupUndoesSplit
shapedValues = TF.constant shape values
in monadicIO $ run $ do
(shapeOut, got, want :: V.Vector a) <-
TF.runSession $ TF.buildAnd TF.run $ do
TF.runSession $ TF.run =<< do
embeddings <- embeddingLookup modShardedValues indicesVector
return (TF.cast (TF.shape embeddings), embeddings, directs)
-- Checks the explicitly documented invariant of embeddingLookup.


@ -13,6 +13,7 @@
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Int (Int32)
@ -40,7 +41,7 @@ testGradientSimple = testCase "testGradientSimple" $ do
y = x*x + b
grads = TF.gradients y [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
[dx, db] <- TF.runSession $ grads >>= TF.run
6 @=? TF.unScalar dx
1 @=? TF.unScalar db
-- Assert that the graph has the expected ops.
@ -91,7 +92,7 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
b = TF.scalar (4 :: Float)
grads = TF.gradients x [x, b]
-- Assert that the gradients are right.
[dx, db] <- TF.runSession $ TF.buildAnd TF.run grads
[dx, db] <- TF.runSession $ grads >>= TF.run
1 @=? TF.unScalar dx
0 @=? TF.unScalar db
-- Assert that the graph has the expected ops.
@ -113,11 +114,11 @@ testGradientDisconnected = testCase "testGradientDisconnected" $ do
-- Test that identical "stateful" ops work with createGraph.
testCreateGraphStateful :: Test
testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
[dx, dy] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx, dy] <- TF.runSession $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
y :: TF.Tensor TF.Value Float <- TF.truncatedNormal shape
TF.gradients (x + y*3) [x, y]
TF.gradients (x + y*3) [x, y] >>= TF.run
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. These asserts are extra.
1 @=? TF.unScalar dx
@ -127,11 +128,11 @@ testCreateGraphStateful = testCase "testCreateGraphStateful" $ do
-- Test that name scopes work with createGraph.
testCreateGraphNameScopes :: Test
testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx] <- TF.runSession $ do
let shape = TF.constant (TF.Shape [1]) [1]
x :: TF.Tensor TF.Value Float <-
TF.withNameScope "foo" (TF.truncatedNormal shape)
TF.gradients x [x]
TF.gradients x [x] >>= TF.run
-- If this test fails, it will likely be caused by an exception within
-- `TF.gradients`. This assert is extra.
1 @=? TF.unScalar dx
@ -140,20 +141,20 @@ testCreateGraphNameScopes = testCase "testCreateGraphNameScopes" $ do
-- Test that createGraph can handle graphs with diamond shapes.
testDiamond :: Test
testDiamond = testCase "testDiamond" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx] <- TF.runSession $ do
let x = TF.vector [1]
y = x*x
z = y*y
TF.gradients z [x]
TF.gradients z [x] >>= TF.run
(4 :: Float) @=? TF.unScalar dx
testMaxGradient :: Test
testMaxGradient = testCase "testMaxGradient" $ do
[dx] <- TF.runSession $ TF.buildAnd TF.run $ do
[dx] <- TF.runSession $ do
let x = TF.vector [1, 2, 3, 0, 1 :: Float]
y = TF.max x (0 :: TF.Tensor TF.Value Int32)
TF.gradients y [x]
TF.gradients y [x] >>= TF.run
V.fromList [0, 0, 1, 0, 0 :: Float] @=? dx


@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
module Main where
@ -19,6 +20,7 @@ module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import Google.Test (googleTest)
import Lens.Family2 ((.~))
import System.IO.Temp (withSystemTempDirectory)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
@ -27,7 +29,6 @@ import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.Build as TF
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.Nodes as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF
@ -41,7 +42,7 @@ testSize = testCase "testSize" $ do
TF.Scalar (2 * 3 :: Int32) @=? x
eval :: TF.Fetchable t a => t -> IO a
eval = TF.runSession . TF.buildAnd TF.run . return
eval = TF.runSession . TF.run
-- | Confirms that the original example from Python code works.
testReducedShape :: Test
@ -54,22 +55,48 @@ testSaveRestore :: Test
testSaveRestore = testCase "testSaveRestore" $
withSystemTempDirectory "" $ \dirPath -> do
let path = B8.pack $ dirPath ++ "/checkpoint"
var :: TF.Build (TF.Tensor TF.Ref Float)
var :: TF.MonadBuild m => m (TF.Tensor TF.Ref Float)
var = TF.render =<<
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
TF.zeroInitializedVariable' (TF.opName .~ "a")
(TF.Shape [])
TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.assign v 134
TF.buildAnd TF.run_ $ TF.save path [v]
v <- var
TF.assign v 134 >>= TF.run_
TF.save path [v] >>= TF.run_
result <- TF.runSession $ do
v <- TF.build var
TF.buildAnd TF.run_ $ TF.restore path v
v <- var
TF.restore path v >>= TF.run_
TF.run v
liftIO $ TF.Scalar 134 @=? result
-- | Test that 'placeholder' is not CSE'd.
testPlaceholderCse :: Test
testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
p1 <- TF.placeholder []
p2 <- TF.placeholder []
let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
liftIO $ result @=? TF.Scalar 5
-- | Test that regular tensors can also be used for feeds, as long as they each
-- have a different name.
testScalarFeedCse :: Test
testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0
-- The second op is identical to the first other than its name; make sure
-- we don't aggressively CSE them together and prevent feeding them
-- separately.
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
let enc :: Float -> TF.TensorData Float
enc n = TF.encodeTensorData [] (V.fromList [n])
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
liftIO $ result @=? TF.Scalar 5
main :: IO ()
main = googleTest [ testSaveRestore
, testSize
, testReducedShape
, testPlaceholderCse
, testScalarFeedCse
]


@ -25,13 +25,13 @@ fit xData yData = TF.runSession $ do
let x = TF.vector xData
y = TF.vector yData
-- Create scalar variables for slope and intercept.
w <- TF.build (TF.initializedVariable 0)
b <- TF.build (TF.initializedVariable 0)
w <- TF.initializedVariable 0
b <- TF.initializedVariable 0
-- Define the loss function.
let yHat = (x `TF.mul` w) `TF.add` b
loss = TF.square (yHat `TF.sub` y)
-- Optimize with gradient descent.
trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
trainStep <- gradientDescent 0.001 loss [w, b]
replicateM_ 1000 (TF.run trainStep)
-- Return the learned parameters.
(TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@ -40,7 +40,7 @@ fit xData yData = TF.runSession $ do
gradientDescent :: Float
-> TF.Tensor TF.Value Float
-> [TF.Tensor TF.Ref Float]
-> TF.Build TF.ControlNode
-> TF.Session TF.ControlNode
gradientDescent alpha loss params = do
let applyGrad param grad =
TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))


@ -35,7 +35,7 @@ testTracing = do
loggedValue <- newEmptyMVar
TF.runSessionWithOptions
(def & TF.sessionTracer .~ putMVar loggedValue)
(TF.buildAnd TF.run_ (pure (TF.scalar (0 :: Float))))
(TF.run_ (TF.scalar (0 :: Float)))
tryReadMVar loggedValue >>=
maybe (assertFailure "Logging never happened") expectedFormat
where expectedFormat x =


@ -36,7 +36,6 @@ import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF


@ -12,67 +12,60 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Queues in TensorFlow graph. Very limited support for now.
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
import TensorFlow.Build (ControlNode, MonadBuild, build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor (Ref, Tensor)
import TensorFlow.Types (TensorType, tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Tensor (Ref, Value, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A queue carrying tuples. The underlying structure is more
-- versatile and can be made to support arbitrary tuples.
data Queue2 a b = Queue2 { handle :: Handle }
-- | A queue carrying tuples.
data Queue (as :: [*]) = Queue { handle :: Handle }
type Handle = Tensor Ref ByteString
-- | Adds the given values to the queue.
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
=> Queue2 a b
-> Tensor v1 a
-> Tensor v2 b
-> Build ControlNode
enqueue q =
buildOp (opDef "QueueEnqueue"
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)])
(handle q)
enqueue :: forall as v m . (MonadBuild m, TensorTypes as)
=> Queue as
-> TensorList v as
-> m ControlNode
enqueue = CoreOps.queueEnqueue . handle
-- | Retrieves the values from the queue.
dequeue :: forall a b . (TensorType a, TensorType b)
=> Queue2 a b
-> Build (Tensor Ref a, Tensor Ref b)
-- ^ Dequeued tensors. They are paired in the sense
dequeue :: forall as m . (MonadBuild m, TensorTypes as)
=> Queue as
-> m (TensorList Value as)
-- ^ Dequeued tensors. They are coupled in the sense
-- that values appear together, even if they are
-- not consumed together.
dequeue q =
buildOp (opDef "QueueDequeue"
& opAttr "component_types" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)])
(handle q)
dequeue = CoreOps.queueDequeue . handle
-- | Creates a new queue with the given capacity and shared name.
makeQueue2 :: forall a b . (TensorType a, TensorType b)
makeQueue :: forall as m . (MonadBuild m, TensorTypes as)
=> Int64 -- ^ The upper bound on the number of elements in
-- this queue. Negative numbers mean no limit.
-> ByteString -- ^ If non-empty, this queue will be shared
-- under the given name across multiple sessions.
-> Build (Queue2 a b)
makeQueue2 capacity sharedName = do
q <- buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)]
-> m (Queue as)
makeQueue capacity sharedName = do
q <- build $ buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
& opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity
)
group q >>= addInitializer
return (Queue2 q)
return (Queue q)
-- TODO(gnezdo): Figure out the closing story for queues.


@ -39,6 +39,7 @@ Test-Suite QueueTest
, lens-family
, google-shim
, tensorflow
, tensorflow-core-ops
, tensorflow-ops
, tensorflow-queue
, test-framework


@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
@ -20,13 +21,12 @@ module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import TensorFlow.Types (Scalar(..))
import TensorFlow.Types (ListOf(..), Scalar(..), (/:/))
import TensorFlow.Ops (scalar)
import TensorFlow.Queue
import TensorFlow.Session
( asyncProdNodes
, build
, buildAnd
, run
, runSession
, run_
@ -39,42 +39,50 @@ import qualified Data.ByteString as BS
-- | Test basic queue behaviors.
testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
buildAnd run_ (enqueue q 42 (scalar "Hi"))
x <- buildAnd run (dequeue q)
liftIO $ (Scalar 42, Scalar "Hi") @=? x
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)
x <- run =<< dequeue q
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
buildAnd run_ (enqueue q 56 (scalar "Bar"))
y <- buildAnd run (dequeue q)
liftIO $ (Scalar 56, Scalar "Bar") @=? y
run_ =<< enqueue q (56 :/ scalar "Bar" :/ Nil)
y <- run =<< dequeue q
-- Note: we use explicit "Scalar" here to specify the type that was
-- fetched. Equivalently we could write
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
-- or else allow the types to be determined by future use of the fetched
-- value.
let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil
liftIO $ expected @=? y
-- | Test queue pumping.
testPump :: Test
testPump = testCase "testPump" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue"
(,) <$> dequeue q
<*> enqueue q 31 (scalar "Baz")
<*> enqueue q (31 :/ scalar "Baz" :/ Nil)
-- This is a realistic use. The pump inputs are pre-bound to some
-- nodes that produce values when pumped (e.g. read from a
-- file).
run_ (pump, pump)
(x, y) <- run (deq, deq)
liftIO $ (Scalar 31, Scalar "Baz") @=? x
liftIO $ (Scalar 31, Scalar "Baz") @=? y
let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil
liftIO $ expected @=? x
liftIO $ expected @=? y
testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
(deq, pump) <- do
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
(,) <$> dequeue q
<*> enqueue q 10 (scalar "Async")
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
-- Pumps the queue until canceled by runSession exiting.
asyncProdNodes pump
-- Picks up a couple values and verifies they are as expected.
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil
run deq >>= liftIO . (expected @=?)
run deq >>= liftIO . (expected @=?)
main :: IO ()
main = googleTest [ testBasic


@ -37,6 +37,7 @@ module TensorFlow.Build
, renderedNodeDefs
, BuildT
, Build
, MonadBuild(..)
, addInitializer
, hoistBuildT
, evalBuildT
@ -212,9 +213,16 @@ runBuildT (BuildT f) = runStateT f initGraphState
evalBuildT :: Monad m => BuildT m a -> m a
evalBuildT (BuildT f) = evalStateT f initGraphState
-- | Lift a 'Build' action into a monad, including any explicit op renderings.
class Monad m => MonadBuild m where
build :: Build a -> m a
instance Monad m => MonadBuild (BuildT m) where
build = hoistBuildT $ return . runIdentity
-- | Get all the NodeDefs that have accumulated so far, and clear that buffer.
flushNodeBuffer :: Monad m => BuildT m [NodeDef]
flushNodeBuffer = do
flushNodeBuffer :: MonadBuild m => m [NodeDef]
flushNodeBuffer = build $ do
ns <- use nodeBuffer
nodeBuffer .= []
return ns
@ -229,8 +237,8 @@ flushInitializers = do
-- | Registers the given node to be executed before the next
-- 'TensorFlow.Session.run'.
addInitializer :: ControlNode -> Build ()
addInitializer (ControlNode o) = do
addInitializer :: MonadBuild m => ControlNode -> m ()
addInitializer (ControlNode o) = build $ do
i <- getOrAddOp o
initializationNodes %= (i:)
@ -242,8 +250,8 @@ asGraphDef b = def & node .~ gs ^. nodeBuffer
gs = snd $ runIdentity $ runBuildT b
-- TODO: check against existing nodes for conflicts?
addGraphDef :: GraphDef -> Build ()
addGraphDef g = nodeBuffer <>= g ^. node
addGraphDef :: MonadBuild m => GraphDef -> m ()
addGraphDef g = build $ nodeBuffer <>= g ^. node
-- | Render the given op if it hasn't been rendered already, and return its
-- name.
@ -318,34 +326,34 @@ renderOutput (Output (OutputIx i) o) = do
-- | Modify some part of the state, run an action, and restore the state
-- after that action is done.
withStateLens :: MonadState s m => Lens' s a -> (a -> a) -> m b -> m b
withStateLens :: MonadBuild m => Lens' GraphState a -> (a -> a) -> m b -> m b
withStateLens accessor f act = do
old <- use accessor
accessor %= f
old <- build $ use accessor
build $ accessor %= f
result <- act
accessor .= old
build $ accessor .= old
return result
-- | Set a device for all nodes rendered in the given 'Build' action
-- (unless further overridden by another use of withDevice).
withDevice :: Maybe Device -> Build a -> Build a
withDevice :: MonadBuild m => Maybe Device -> m a -> m a
withDevice d = withStateLens defaultDevice (const d)
-- | Places all nodes rendered in the given 'Build' action on the same
-- device as the given Tensor (see also 'withDevice'). Make sure that
-- the action has side effects of rendering the desired tensors. A pure
-- return would not have the desired effect.
colocateWith :: forall a v b . Tensor v b -> Build a -> Build a
colocateWith :: MonadBuild m => forall a v b . Tensor v b -> m a -> m a
colocateWith t x = do
d <- Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
d <- build $ Device . (^. device) <$> resolveOp (t ^. tensorOutput . outputOp)
withDevice (Just d) x
-- | Prepend a scope to all nodes rendered in the given 'Build' action.
withNameScope :: Text -> Build a -> Build a
withNameScope :: MonadBuild m => Text -> m a -> m a
withNameScope s = withStateLens currentScope (Scope s :)
-- | Add control inputs to all nodes rendered in the given 'Build' action.
withNodeDependencies :: Set NodeName -> Build a -> Build a
withNodeDependencies :: MonadBuild m => Set NodeName -> m a -> m a
withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- | Render a 'Tensor', fixing its name, scope, device and control inputs from
@ -355,8 +363,8 @@ withNodeDependencies nodes = withStateLens defaultControlInputs (<> nodes)
-- This operation is idempotent; @render >=> render === render@. However,
-- rendering a (previously un-rendered) 'Tensor' in two different contexts
-- may result in two different 'Tensor's.
render :: Tensor v a -> Build (Tensor v a)
render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
render :: MonadBuild m => Tensor v a -> m (Tensor v a)
render = build . tensorOutput (outputOp $ fmap Rendered . resolveOp)
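-- For example (a sketch), rendering once and reusing the result ensures
-- that later uses refer to the same node instead of duplicating the op:
--
-- > y <- render (x + x)
-- > -- every subsequent use of @y@ names the same rendered op.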
-- | Render a 'Tensor' and get its node's name.
renderNodeName :: Tensor v a -> Build NodeName

View File

@ -13,6 +13,8 @@
-- limitations under the License.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
@ -21,6 +23,7 @@ module TensorFlow.BuildOp
, buildOp
, buildListOp
, eqLengthGuard
, OpParams
)
where
@ -33,6 +36,7 @@ import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
data ResultState = ResultState !OutputIx [Int64] deriving Show
@ -98,6 +102,22 @@ instance OpResult (Tensor Ref a) where
instance OpResult ControlNode where
toResult = ControlNode <$> ask
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> Result (TensorList v bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- tensorResult v
ts <- loop ls
return (t :/ ts)
instance TensorTypes as => OpResult (TensorList Value as) where
toResult = tensorListResult ValueKind
instance TensorTypes as => OpResult (TensorList Ref as) where
toResult = tensorListResult RefKind
instance OpResult a => OpResult [a] where
toResult = do
ResultState i ns <- get
@ -159,6 +179,12 @@ instance BuildOp (Tensor Value a) where
instance BuildOp (Tensor Ref a) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Value as) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Ref as) where
buildOp' = pureResult
instance BuildOp [Tensor Value a] where
buildOp' = pureResult
@ -199,6 +225,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
instance BuildOp f => BuildOp (TensorList v as -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
-- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
@ -209,3 +239,7 @@ eqLengthGuard = all eachOk
eachOk (numberAttrName, pairs@((_, x) : zs)) = all (\z -> snd z == x) zs ||
error ("number_attr " ++ numberAttrName ++
" contains tensors with different length " ++ show pairs)
-- | Parameters to build an op (for example, the node name or optional attributes).
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef
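-- For example (a sketch; @scalar'@ stands for a generated op wrapper that
-- takes 'OpParams' as its first argument):
--
-- > five = scalar' (opName .~ "five") (5 :: Float)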

View File

@ -22,27 +22,21 @@ module TensorFlow.ControlFlow
withControlDependencies
, group
-- * Operations
, identity
, noOp
, named
) where
import qualified Data.Set as Set
import Data.Text (Text)
import Lens.Family2 ((&), (^.), (.~))
import Lens.Family2 ((&), (.~))
import TensorFlow.BuildOp
import TensorFlow.Build
import TensorFlow.Nodes
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
-- | Modify a 'Build' action, such that all new ops rendered in it will depend
-- on the nodes in the first argument.
withControlDependencies :: Nodes t => t -> Build a -> Build a
withControlDependencies :: (MonadBuild m, Nodes t) => t -> m a -> m a
withControlDependencies deps act = do
nodes <- getNodes deps
nodes <- build $ getNodes deps
withNodeDependencies nodes act
-- TODO(judahjacobson): Reimplement withDependencies.
@ -51,37 +45,12 @@ withControlDependencies deps act = do
--
-- When this op finishes, all ops in the input @n@ have finished. This op has
-- no output.
group :: Nodes t => t -> Build ControlNode
group :: (MonadBuild m, Nodes t) => t -> m ControlNode
group deps = do
nodes <- Set.toList <$> getNodes deps
nodes <- build $ Set.toList <$> getNodes deps
-- TODO: slicker way
return $ buildOp $ opDef "NoOp" & opControlInputs .~ nodes
-- | Returns a 'Tensor' with the same shape and contents as the input.
identity :: TensorType a => Tensor v a -> Tensor v a
identity = namedIdentity implicitName
-- | Returns a 'Tensor' with a given name and the same shape and contents as
-- the input.
--
-- TODO(judahjacobson): This breaks when used with uninitialized @Tensor Ref@s,
-- since @RefIdentity@ doesn't have SetAllowsUninitializedInput(). Look into
-- whether we can change that op.
named :: TensorType a => Text -> Tensor v a -> Tensor v a
named = namedIdentity . explicitName
-- | An internal version of "identity" that allows setting the name
-- of the output Tensor.
namedIdentity :: forall a v . TensorType a
=> PendingNodeName -> Tensor v a -> Tensor v a
namedIdentity n t = case t ^. tensorKind of
ValueKind -> buildOp (opDefWithName n "Identity" & setTypeAttr) t
RefKind -> buildOp (opDefWithName n "RefIdentity" & setTypeAttr) t
where
setTypeAttr = opAttr "T" .~ tensorType (undefined :: a)
-- | Does nothing. Only useful as a placeholder for control edges.
noOp :: ControlNode
noOp = buildOp $ opDef "NoOp"
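-- For example (a sketch; 'initializedVariable' and 'assign' as in
-- tensorflow-ops, 'run_' from TensorFlow.Session), grouping two
-- assignments into a single step:
--
-- > w <- initializedVariable 0
-- > b <- initializedVariable 0
-- > step <- group =<< ((,) <$> assign w 1 <*> assign b 2)
-- > run_ step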

View File

@ -31,8 +31,7 @@ module TensorFlow.Core
, runSession
, runSessionWithOptions
-- ** Building graphs
, build
, buildAnd
, MonadBuild(..)
-- ** Running graphs
, Fetchable
, Nodes
@ -51,14 +50,14 @@ module TensorFlow.Core
, render
, asGraphDef
, addGraphDef
, opName
, opAttr
-- * Tensor
, ControlNode
, Tensor
, Value
, Ref
, TensorKind(..)
, tensorAttr
, value
, tensorFromName
-- ** Element types
@ -75,12 +74,10 @@ module TensorFlow.Core
, Device(..)
, withDevice
, withNameScope
, named
-- ** Dependencies
, withControlDependencies
, group
-- ** Misc
, identity
, noOp
) where

View File

@ -12,15 +12,18 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3)
import Data.Functor.Identity (Identity)
import Data.Map.Strict (Map)
import Data.Monoid ((<>))
import Data.Set (Set)
@ -96,6 +99,19 @@ instance Nodes ControlNode where
instance a ~ () => Fetchable ControlNode a where
getFetch _ = return $ pure ()
instance Nodes (ListOf f '[]) where
getNodes _ = return Set.empty
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
getFetch _ = return $ pure Nil
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
=> Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
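-- For example (a sketch), fetching a heterogeneous 'TensorList' in a single
-- 'run' call yields a 'List' of results, as in the queue tests:
--
-- > xs <- run (t1 :/ t2 :/ Nil)
-- > -- xs :: List [Scalar Int64, Scalar ByteString]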
instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)

View File

@ -124,6 +124,9 @@ data OpDef = OpDef
data PendingNodeName = ExplicitName !Text | ImplicitName
deriving (Eq, Ord, Show)
instance IsString PendingNodeName where
fromString = ExplicitName . fromString
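-- For example, with @OverloadedStrings@ a string literal stands for an
-- 'ExplicitName' (a sketch):
--
-- > opDefWithName "myAdd" "Add"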
-- | The name of a node in the graph. This corresponds to the proto field
-- NodeDef.name. Includes the scope prefix (if any) and a unique identifier
-- (if the node was implicitly named).

View File

@ -26,8 +26,7 @@ module TensorFlow.Session (
sessionTracer,
runSession,
runSessionWithOptions,
build,
buildAnd,
MonadBuild(..),
extend,
addGraphDef,
run,
@ -44,7 +43,6 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
import Data.ByteString (ByteString)
import Data.Default (Default, def)
import Data.Functor.Identity (runIdentity)
import Data.Monoid ((<>))
import Data.ProtoLens (showMessage)
import Data.Set (Set)
@ -124,10 +122,8 @@ runSessionWithOptions options (Session m) =
FFI.setSessionTarget (options ^. sessionTarget) opt
FFI.setSessionConfig (options ^. sessionConfig) opt
-- | Lift a 'Build' action into a 'Session', including any explicit op
-- renderings.
build :: Build a -> Session a
build = Session . lift . hoistBuildT (return . runIdentity)
instance MonadBuild Session where
build = Session . lift . build
-- | Add all pending rendered nodes to the TensorFlow graph and runs
-- any pending initializers.
@ -147,13 +143,6 @@ extend = do
unless (null initializers) $
void $ liftIO $ FFI.run session [] [] (toNodeNames initializers)
-- | Helper combinator for doing something with the result of a 'Build' action.
-- Example usage:
--
-- > buildAnd run :: Fetchable t a => Build t -> Session a
buildAnd :: (a -> Session b) -> Build a -> Session b
buildAnd f m = build m >>= f
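-- With 'MonadBuild', the same effect is now a plain bind in 'Session'
-- (a sketch, mirroring the queue tests):
--
-- > run_ =<< enqueue q (42 :/ scalar "Hi" :/ Nil)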
-- | Run a subgraph 't', rendering any dependent nodes that aren't already
-- rendered, and fetch the corresponding values for 'a'.
run :: Fetchable t a => t -> Session a

View File

@ -12,20 +12,29 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Tensor where
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal')
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
import TensorFlow.Types (TensorData(..), Attribute)
import TensorFlow.Output (Output)
import TensorFlow.Types
( TensorData(..)
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation.
@ -52,15 +61,6 @@ tensorKind = lens (\(Tensor v _) -> v) (\(Tensor _ o) v -> Tensor v o)
tensorOutput :: Lens' (Tensor v a) Output
tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
-- TODO: Come up with a better API for handling attributes.
-- | Lens for the attributes of a tensor.
--
-- Only valid if the tensor has not yet been rendered. If the tensor has been
-- rendered, the traversal will be over nothing (nothing can be read or
-- written).
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
-- Ref into Value. This behaves like a no-op.
value :: Tensor v a -> Tensor Value a
@ -83,3 +83,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td
-- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
tensorFromName v = Tensor v . fromString . Text.unpack
type TensorList v = ListOf (Tensor v)
tensorListOutputs :: TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts
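-- For example (a sketch; 'scalar' from tensorflow-ops):
--
-- > xs :: TensorList Value '[Int64, ByteString]
-- > xs = scalar 3 :/ scalar "three" :/ Nil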

View File

@ -13,9 +13,11 @@
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
@ -36,23 +38,35 @@ module TensorFlow.Types
, Shape(..)
, protoShape
, Attribute(..)
, DataType(..)
-- * Lists
, ListOf(..)
, List
, (/:/)
, TensorTypeProxy(..)
, TensorTypes(..)
, TensorTypeList
, fromTensorTypeList
, fromTensorTypes
-- * Type constraints
, OneOf
, type (/=)
, OneOfs
-- ** Implementation of constraints
, TypeError
, ExcludedCase
, TensorTypes
, NoneOf
, type (\\)
, Delete
, AllTensorTypes
) where
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
@ -376,6 +390,44 @@ instance Attribute [DataType] where
instance Attribute [Int64] where
attrLens = list . i
-- | A heterogeneous list type.
data ListOf f as where
Nil :: ListOf f '[]
(:/) :: f a -> ListOf f as -> ListOf f (a ': as)
infixr 5 :/
type family All f as :: Constraint where
All f '[] = ()
All f (a ': as) = (f a, All f as)
type family Map f as where
Map f '[] = '[]
Map f (a ': as) = f a ': Map f as
instance All Eq (Map f as) => Eq (ListOf f as) where
Nil == Nil = True
(x :/ xs) == (y :/ ys) = x == y && xs == ys
-- Newer versions of GHC use the GADT to tell that the previous cases are
-- exhaustive.
#if __GLASGOW_HASKELL__ < 800
_ == _ = False
#endif
instance All Show (Map f as) => Show (ListOf f as) where
showsPrec _ Nil = showString "Nil"
showsPrec d (x :/ xs) = showParen (d > 10)
$ showsPrec 6 x . showString " :/ "
. showsPrec 6 xs
type List = ListOf Identity
-- | Equivalent of ':/' for lists.
(/:/) :: a -> List as -> List (a ': as)
(/:/) = (:/) . Identity
infixr 5 /:/
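-- For example (a sketch):
--
-- > ps :: List '[Int64, Bool]
-- > ps = 3 /:/ True /:/ Nil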
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
@ -393,13 +445,38 @@ instance Attribute [Int64] where
--
-- using an enumeration of all the possible 'TensorType's.
type OneOf ts a
-- Assert `TensorTypes ts` to make error messages a little better.
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
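-- For example (a sketch; a hypothetical op constrained to floating-point
-- types):
--
-- > mySigmoid :: OneOf '[Float, Double] t => Tensor v t -> Tensor Value t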
-- | A check that the input is a list of 'TensorType's.
-- Helps improve error messages when using 'OneOf'.
type OneOfs ts as = (TensorTypes as, TensorTypes ts,
NoneOfs (AllTensorTypes \\ ts) as)
type family NoneOfs ts as :: Constraint where
NoneOfs ts '[] = ()
NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)
data TensorTypeProxy a where
TensorTypeProxy :: TensorType a => TensorTypeProxy a
type TensorTypeList = ListOf TensorTypeProxy
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList Nil = []
fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts)
= tensorType (undefined :: t) : fromTensorTypeList ts
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
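-- For example (a sketch; @DT_INT64@ and @DT_STRING@ from the generated
-- proto enums):
--
-- > fromTensorTypes (Proxy :: Proxy '[Int64, ByteString])
-- >     == [DT_INT64, DT_STRING]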
class TensorTypes (ts :: [*]) where
    tensorTypes :: TensorTypeList ts
instance TensorTypes '[]
instance TensorTypes '[] where
    tensorTypes = Nil
-- | A constraint that the input is a list of 'TensorTypes'.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts)
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
    tensorTypes = TensorTypeProxy :/ tensorTypes
-- | A constraint checking that two types are different.
type family a /= b :: Constraint where