mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
MNIST Main compiles but is broken.
This commit is contained in:
parent
2b5e41ffeb
commit
1677c346eb
11 changed files with 69 additions and 48 deletions
|
@ -15,7 +15,7 @@
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE OverloadedLists #-}
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
|
|
||||||
import Control.Monad (zipWithM, when, forM_)
|
import Control.Monad (forM_, when)
|
||||||
import Control.Monad.IO.Class (liftIO)
|
import Control.Monad.IO.Class (liftIO)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.List (genericLength)
|
import Data.List (genericLength)
|
||||||
|
@ -30,6 +30,7 @@ import qualified TensorFlow.Ops as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
import qualified TensorFlow.Tensor as TF
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
import TensorFlow.Examples.MNIST.InputData
|
import TensorFlow.Examples.MNIST.InputData
|
||||||
import TensorFlow.Examples.MNIST.Parse
|
import TensorFlow.Examples.MNIST.Parse
|
||||||
|
@ -64,6 +65,7 @@ data Model = Model {
|
||||||
|
|
||||||
createModel :: TF.Build Model
|
createModel :: TF.Build Model
|
||||||
createModel = do
|
createModel = do
|
||||||
|
let rd = CoreOps.readVariableOp
|
||||||
-- Use -1 batch size to support variable sized batches.
|
-- Use -1 batch size to support variable sized batches.
|
||||||
let batchSize = -1
|
let batchSize = -1
|
||||||
-- Inputs.
|
-- Inputs.
|
||||||
|
@ -73,13 +75,14 @@ createModel = do
|
||||||
hiddenWeights <-
|
hiddenWeights <-
|
||||||
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
TF.initializedVariable =<< randomParam numPixels [numPixels, numUnits]
|
||||||
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
hiddenBiases <- TF.zeroInitializedVariable [numUnits]
|
||||||
let hiddenZ = (images `TF.matMul` hiddenWeights) `TF.add` hiddenBiases
|
let hiddenZ = (images `TF.matMul` rd hiddenWeights)
|
||||||
|
`TF.add` rd hiddenBiases
|
||||||
let hidden = TF.relu hiddenZ
|
let hidden = TF.relu hiddenZ
|
||||||
-- Logits.
|
-- Logits.
|
||||||
logitWeights <-
|
logitWeights <-
|
||||||
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
TF.initializedVariable =<< randomParam numUnits [numUnits, numLabels]
|
||||||
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
logitBiases <- TF.zeroInitializedVariable [numLabels]
|
||||||
let logits = (hidden `TF.matMul` logitWeights) `TF.add` logitBiases
|
let logits = (hidden `TF.matMul` rd logitWeights) `TF.add` rd logitBiases
|
||||||
predict <- TF.render $ TF.cast $
|
predict <- TF.render $ TF.cast $
|
||||||
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
TF.argMax (TF.softmax logits) (TF.scalar (1 :: LabelType))
|
||||||
|
|
||||||
|
@ -89,11 +92,11 @@ createModel = do
|
||||||
loss =
|
loss =
|
||||||
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
|
||||||
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
|
||||||
grads <- TF.gradients loss params
|
grads <- TF.gradients loss (map rd params)
|
||||||
|
|
||||||
let lr = TF.scalar 0.00001
|
let lr = TF.scalar (-0.00001 :: Float) -- Negative to make it descend.
|
||||||
applyGrad param grad = TF.assign param $ param `TF.sub` (lr * grad)
|
applyGrad var grad = CoreOps.assignVariableOp var (lr * grad)
|
||||||
trainStep <- TF.group =<< zipWithM applyGrad params grads
|
trainStep <- TF.group (zipWith applyGrad params grads)
|
||||||
|
|
||||||
let correctPredictions = TF.equal predict labels
|
let correctPredictions = TF.equal predict labels
|
||||||
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
|
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
|
||||||
|
|
|
@ -45,6 +45,7 @@ executable Main
|
||||||
, lens-family
|
, lens-family
|
||||||
, proto-lens
|
, proto-lens
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
, tensorflow-mnist
|
, tensorflow-mnist
|
||||||
, tensorflow-mnist-input-data
|
, tensorflow-mnist-input-data
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
|
|
|
@ -654,6 +654,7 @@ opGrad "ZerosLike" _ _ _ = [Nothing]
|
||||||
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
opGrad "Const" _ _ _ = [Nothing, Nothing]
|
||||||
opGrad "Placeholder" _ _ _ = []
|
opGrad "Placeholder" _ _ _ = []
|
||||||
opGrad "Variable" _ _ _ = []
|
opGrad "Variable" _ _ _ = []
|
||||||
|
opGrad "ReadVariableOp" _ _ _ = [Nothing, Nothing]
|
||||||
|
|
||||||
opGrad n nodeDef ins grads =
|
opGrad n nodeDef ins grads =
|
||||||
error $ "no gradient implemented for " ++
|
error $ "no gradient implemented for " ++
|
||||||
|
@ -699,6 +700,7 @@ numOutputs o =
|
||||||
"Transpose" -> 1
|
"Transpose" -> 1
|
||||||
"TruncatedNormal" -> 1
|
"TruncatedNormal" -> 1
|
||||||
"Variable" -> 1
|
"Variable" -> 1
|
||||||
|
"ReadVariableOp" -> 2
|
||||||
"ZerosLike" -> 1
|
"ZerosLike" -> 1
|
||||||
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
_ -> error $ "numOuputs not implemented for " ++ show (o ^. op)
|
||||||
|
|
||||||
|
|
|
@ -121,8 +121,7 @@ import qualified Proto.Tensorflow.Core.Framework.TensorShape
|
||||||
as TensorShape
|
as TensorShape
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.BuildOp
|
import TensorFlow.BuildOp
|
||||||
import TensorFlow.ControlFlow (group)
|
import TensorFlow.Output (ResourceHandle, unNodeName)
|
||||||
import TensorFlow.Output (unNodeName)
|
|
||||||
import TensorFlow.Tensor
|
import TensorFlow.Tensor
|
||||||
import TensorFlow.Types
|
import TensorFlow.Types
|
||||||
|
|
||||||
|
@ -176,37 +175,30 @@ assign = buildOp $ opDef "Assign"
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
initializedVariable :: forall a . TensorType a
|
initializedVariable :: forall a . TensorType a
|
||||||
=> Tensor Value a -> Build (Tensor Ref a)
|
=> Tensor Value a -> Build (ResourceHandle a)
|
||||||
initializedVariable initializer = do
|
initializedVariable initializer = do
|
||||||
v <- variable [] -- The shape is not known initially.
|
let v = CoreOps.varHandleOp
|
||||||
(i :: Tensor Ref a) <-
|
& resourceHandleAttr "shape" .~ (Shape [])
|
||||||
buildOp (opDef "Assign"
|
addInitializer (CoreOps.createVariableOp v initializer)
|
||||||
& opAttr "T" .~ tensorType (undefined :: a)
|
|
||||||
& opAttr "use_locking" .~ True
|
|
||||||
& opAttr "validate_shape" .~ False
|
|
||||||
)
|
|
||||||
v initializer
|
|
||||||
addInitializer =<< group i
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
-- | Creates a zero-initialized variable with the given shape.
|
-- | Creates a zero-initialized variable with the given shape.
|
||||||
zeroInitializedVariable
|
zeroInitializedVariable
|
||||||
:: (TensorType a, Num a) =>
|
:: (TensorType a, Num a) => TensorFlow.Types.Shape -> Build (ResourceHandle a)
|
||||||
TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
|
|
||||||
zeroInitializedVariable = initializedVariable . zeros
|
zeroInitializedVariable = initializedVariable . zeros
|
||||||
|
|
||||||
-- TODO: Support heterogeneous list of tensors.
|
-- TODO: Support heterogeneous list of tensors.
|
||||||
save :: forall a v . TensorType a
|
save :: forall a . TensorType a
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> [Tensor v a] -- ^ Tensors to save.
|
-> [ResourceHandle a] -- ^ Tensors to save.
|
||||||
-> Build ControlNode
|
-> Build ControlNode
|
||||||
save path xs = do
|
save path xs = do
|
||||||
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
let toByteStringTensor = scalar . encodeUtf8 . unNodeName
|
||||||
names <- mapM (fmap toByteStringTensor . renderNodeName) xs
|
names <- mapM (fmap toByteStringTensor . renderResourceHandle) xs
|
||||||
let types = replicate (length xs) (tensorType (undefined :: a))
|
let types = replicate (length xs) (tensorType (undefined :: a))
|
||||||
let saveOp = buildOp $ opDef "Save"
|
let saveOp = buildOp $ opDef "Save"
|
||||||
& opAttr "T" .~ types
|
& opAttr "T" .~ types
|
||||||
saveOp (scalar path) (CoreOps.pack names) xs
|
saveOp (scalar path) (CoreOps.pack names) (map CoreOps.readVariableOp xs)
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
--
|
--
|
||||||
|
@ -215,21 +207,21 @@ save path xs = do
|
||||||
restoreFromName :: forall a . TensorType a
|
restoreFromName :: forall a . TensorType a
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> ByteString -- ^ Tensor name override.
|
-> ByteString -- ^ Tensor name override.
|
||||||
-> Tensor Ref a -- ^ Tensor to restore.
|
-> Build (Tensor Value a)
|
||||||
-> Build ControlNode
|
restoreFromName path name = do
|
||||||
restoreFromName path name x = do
|
|
||||||
let restoreOp = buildOp $ opDef "Restore"
|
let restoreOp = buildOp $ opDef "Restore"
|
||||||
& opAttr "dt" .~ tensorType (undefined :: a)
|
& opAttr "dt" .~ tensorType (undefined :: a)
|
||||||
group =<< assign x (restoreOp (scalar path) (scalar name) :: Tensor Value a)
|
restoreOp (scalar path) (scalar name)
|
||||||
|
|
||||||
|
|
||||||
-- | Restore a tensor's value from a checkpoint file.
|
-- | Restore a tensor's value from a checkpoint file.
|
||||||
restore :: forall a . TensorType a
|
restore :: forall a . TensorType a
|
||||||
=> ByteString -- ^ File path.
|
=> ByteString -- ^ File path.
|
||||||
-> Tensor Ref a -- ^ Tensor to restore.
|
-> ResourceHandle a
|
||||||
-> Build ControlNode
|
-> Build ControlNode
|
||||||
restore path x = do
|
restore path x = do
|
||||||
name <- encodeUtf8 . unNodeName <$> renderNodeName x
|
name <- encodeUtf8 . unNodeName <$> renderResourceHandle x
|
||||||
restoreFromName path name x
|
CoreOps.assignVariableOp x <$> restoreFromName path name
|
||||||
|
|
||||||
-- | Create a constant tensor.
|
-- | Create a constant tensor.
|
||||||
--
|
--
|
||||||
|
@ -253,12 +245,13 @@ constant (Shape shape') values
|
||||||
typedNode :: TensorProto
|
typedNode :: TensorProto
|
||||||
typedNode = def
|
typedNode = def
|
||||||
& dtype .~ nodeType
|
& dtype .~ nodeType
|
||||||
|
-- Use shapeToProto from Types.hs
|
||||||
& tensorShape.TensorShape.dim .~
|
& tensorShape.TensorShape.dim .~
|
||||||
[def & TensorShape.size .~ x | x <- shape']
|
[def & TensorShape.size .~ x | x <- shape']
|
||||||
& tensorVal .~ values
|
& tensorVal .~ values
|
||||||
|
|
||||||
-- | Reshape a N-D tensor down to a scalar.
|
-- | Reshape a N-D tensor down to a scalar.
|
||||||
--
|
--
|
||||||
-- See `TensorFlow.GenOps.Core.reshape`.
|
-- See `TensorFlow.GenOps.Core.reshape`.
|
||||||
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
scalarize :: (TensorType a) => Tensor v a -> Tensor Value a
|
||||||
scalarize t = CoreOps.reshape t (vector scalarShape)
|
scalarize t = CoreOps.reshape t (vector scalarShape)
|
||||||
|
|
|
@ -41,6 +41,7 @@ Test-Suite BuildTest
|
||||||
, lens-family
|
, lens-family
|
||||||
, google-shim
|
, google-shim
|
||||||
, tensorflow
|
, tensorflow
|
||||||
|
, tensorflow-core-ops
|
||||||
, tensorflow-ops
|
, tensorflow-ops
|
||||||
, tensorflow-proto
|
, tensorflow-proto
|
||||||
, test-framework
|
, test-framework
|
||||||
|
|
|
@ -64,6 +64,7 @@ import Test.Framework.Providers.HUnit (testCase)
|
||||||
import Test.HUnit ((@=?))
|
import Test.HUnit ((@=?))
|
||||||
import Google.Test (googleTest)
|
import Google.Test (googleTest)
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
|
|
||||||
-- | Test named behavior.
|
-- | Test named behavior.
|
||||||
testNamed :: Test
|
testNamed :: Test
|
||||||
|
@ -96,13 +97,12 @@ testPureRender = testCase "testPureRender" $ runSession $ do
|
||||||
testInitializedVariable :: Test
|
testInitializedVariable :: Test
|
||||||
testInitializedVariable =
|
testInitializedVariable =
|
||||||
testCase "testInitializedVariable" $ runSession $ do
|
testCase "testInitializedVariable" $ runSession $ do
|
||||||
(formula, reset) <- build $ do
|
(v, formula) <- build $ do
|
||||||
v <- initializedVariable 42
|
v <- initializedVariable 42
|
||||||
r <- assign v 24
|
return (v, 1 `add` CoreOps.readVariableOp v)
|
||||||
return (1 `add` v, r)
|
|
||||||
result <- run formula
|
result <- run formula
|
||||||
liftIO $ 43 @=? (unScalar result :: Float)
|
liftIO $ 43 @=? (unScalar result :: Float)
|
||||||
run_ reset -- Updates v to a different value
|
run_ (CoreOps.assignVariableOp v 24)
|
||||||
rerunResult <- run formula
|
rerunResult <- run formula
|
||||||
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
liftIO $ 25 @=? (unScalar rerunResult :: Float)
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ testInitializedVariableShape :: Test
|
||||||
testInitializedVariableShape =
|
testInitializedVariableShape =
|
||||||
testCase "testInitializedVariableShape" $ runSession $ do
|
testCase "testInitializedVariableShape" $ runSession $ do
|
||||||
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
vector <- build $ initializedVariable (constant [1] [42 :: Float])
|
||||||
result <- run vector
|
result <- run (CoreOps.readVariableOp vector)
|
||||||
liftIO $ [42] @=? (result :: V.Vector Float)
|
liftIO $ [42] @=? (result :: V.Vector Float)
|
||||||
|
|
||||||
-- | Test nameScoped behavior.
|
-- | Test nameScoped behavior.
|
||||||
|
|
|
@ -112,8 +112,9 @@ testEmbeddingLookupGradients = testCase "testEmbeddingLookupGradients" $ do
|
||||||
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
let ids = TF.constant (TF.Shape [1, 2]) idValues
|
||||||
|
|
||||||
x <- TF.placeholder (TF.Shape [2])
|
x <- TF.placeholder (TF.Shape [2])
|
||||||
embedding <- TF.initializedVariable
|
embedding <- CoreOps.readVariableOp <$>
|
||||||
=<< TF.render (TF.constant embShape embeddingInit)
|
(TF.initializedVariable
|
||||||
|
=<< TF.render (TF.constant embShape embeddingInit))
|
||||||
|
|
||||||
op <- embeddingLookup [embedding] ids
|
op <- embeddingLookup [embedding] ids
|
||||||
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
let twoNorm = CoreOps.square $ TF.abs (op - x)
|
||||||
|
|
|
@ -28,10 +28,11 @@ import qualified Data.ByteString.Char8 as B8
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
import qualified TensorFlow.Build as TF
|
import qualified TensorFlow.Build as TF
|
||||||
import qualified TensorFlow.ControlFlow as TF
|
import qualified TensorFlow.ControlFlow as TF
|
||||||
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
import qualified TensorFlow.Nodes as TF
|
import qualified TensorFlow.Nodes as TF
|
||||||
import qualified TensorFlow.Ops as TF
|
import qualified TensorFlow.Ops as TF
|
||||||
|
import qualified TensorFlow.Output as TF
|
||||||
import qualified TensorFlow.Session as TF
|
import qualified TensorFlow.Session as TF
|
||||||
import qualified TensorFlow.Tensor as TF
|
|
||||||
import qualified TensorFlow.Types as TF
|
import qualified TensorFlow.Types as TF
|
||||||
|
|
||||||
-- | Test that one can easily determine number of elements in the tensor.
|
-- | Test that one can easily determine number of elements in the tensor.
|
||||||
|
@ -54,17 +55,16 @@ testSaveRestore :: Test
|
||||||
testSaveRestore = testCase "testSaveRestore" $
|
testSaveRestore = testCase "testSaveRestore" $
|
||||||
withSystemTempDirectory "" $ \dirPath -> do
|
withSystemTempDirectory "" $ \dirPath -> do
|
||||||
let path = B8.pack $ dirPath ++ "/checkpoint"
|
let path = B8.pack $ dirPath ++ "/checkpoint"
|
||||||
var :: TF.Build (TF.Tensor TF.Ref Float)
|
var :: TF.Build (TF.ResourceHandle Float)
|
||||||
var = TF.render =<<
|
var = TF.zeroInitializedVariable (TF.Shape [])
|
||||||
TF.named "a" <$> TF.zeroInitializedVariable (TF.Shape [])
|
|
||||||
TF.runSession $ do
|
TF.runSession $ do
|
||||||
v <- TF.build var
|
v <- TF.build var
|
||||||
TF.buildAnd TF.run_ $ TF.assign v 134
|
TF.buildAnd TF.run_ $ TF.group $ CoreOps.assignVariableOp v 134
|
||||||
TF.buildAnd TF.run_ $ TF.save path [v]
|
TF.buildAnd TF.run_ $ TF.save path [v]
|
||||||
result <- TF.runSession $ do
|
result <- TF.runSession $ do
|
||||||
v <- TF.build var
|
v <- TF.build var
|
||||||
TF.buildAnd TF.run_ $ TF.restore path v
|
TF.buildAnd TF.run_ $ TF.restore path v
|
||||||
TF.run v
|
TF.run (CoreOps.readVariableOp v)
|
||||||
liftIO $ TF.Scalar 134 @=? result
|
liftIO $ TF.Scalar 134 @=? result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ module TensorFlow.Build
|
||||||
, GraphState
|
, GraphState
|
||||||
, render
|
, render
|
||||||
, renderNodeName
|
, renderNodeName
|
||||||
|
, renderResourceHandle
|
||||||
, renderedNodeDefs
|
, renderedNodeDefs
|
||||||
, BuildT
|
, BuildT
|
||||||
, Build
|
, Build
|
||||||
|
@ -361,6 +362,9 @@ render = tensorOutput $ outputOp $ fmap Rendered . resolveOp
|
||||||
renderNodeName :: Tensor v a -> Build NodeName
|
renderNodeName :: Tensor v a -> Build NodeName
|
||||||
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
renderNodeName t = getOrAddOp (t ^. tensorOutput . outputOp)
|
||||||
|
|
||||||
|
renderResourceHandle :: ResourceHandle a -> Build NodeName
|
||||||
|
renderResourceHandle (ResourceHandle r) = getOrAddOp (r ^. outputOp)
|
||||||
|
|
||||||
-- | Records the given summary action in Build for retrieval with
|
-- | Records the given summary action in Build for retrieval with
|
||||||
-- 'collectAllSummaries'. The summary op is required to produce a
|
-- 'collectAllSummaries'. The summary op is required to produce a
|
||||||
-- Summary protocol buffer in string form. For safety, use the
|
-- Summary protocol buffer in string form. For safety, use the
|
||||||
|
|
|
@ -38,6 +38,7 @@ module TensorFlow.Output
|
||||||
, outputOp
|
, outputOp
|
||||||
, PendingNodeName(..)
|
, PendingNodeName(..)
|
||||||
, ResourceHandle(..)
|
, ResourceHandle(..)
|
||||||
|
, resourceHandleOutput
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
|
@ -160,4 +161,7 @@ instance IsString Output where
|
||||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||||
-- resources are variables. The type parameter corresponds to the
|
-- resources are variables. The type parameter corresponds to the
|
||||||
-- dtype of the tensor held in the variable.
|
-- dtype of the tensor held in the variable.
|
||||||
newtype ResourceHandle a = ResourceHandle Output
|
newtype ResourceHandle a = ResourceHandle { unResourceHandle :: Output }
|
||||||
|
|
||||||
|
resourceHandleOutput :: Lens' (ResourceHandle a) Output
|
||||||
|
resourceHandleOutput = lens unResourceHandle (\_ x -> ResourceHandle x)
|
||||||
|
|
|
@ -24,7 +24,14 @@ import qualified Data.Text as Text
|
||||||
import Lens.Family2 (Lens', Traversal')
|
import Lens.Family2 (Lens', Traversal')
|
||||||
import Lens.Family2.Unchecked (lens)
|
import Lens.Family2.Unchecked (lens)
|
||||||
|
|
||||||
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
|
import TensorFlow.Output
|
||||||
|
( Output
|
||||||
|
, ResourceHandle
|
||||||
|
, outputOp
|
||||||
|
, opUnrendered
|
||||||
|
, opAttr
|
||||||
|
, resourceHandleOutput
|
||||||
|
)
|
||||||
import TensorFlow.Types (TensorData(..), Attribute)
|
import TensorFlow.Types (TensorData(..), Attribute)
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
|
||||||
|
@ -61,6 +68,11 @@ tensorOutput = lens (\(Tensor _ o) -> o) (\(Tensor v _) o -> Tensor v o)
|
||||||
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
|
||||||
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
tensorAttr x = tensorOutput . outputOp . opUnrendered . opAttr x
|
||||||
|
|
||||||
|
resourceHandleAttr :: Attribute attr
|
||||||
|
=> Text.Text -> Traversal' (ResourceHandle a) attr
|
||||||
|
resourceHandleAttr attr =
|
||||||
|
resourceHandleOutput . outputOp . opUnrendered . opAttr attr
|
||||||
|
|
||||||
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
-- | Cast a 'Tensor *' into a 'Tensor Value'. Common usage is to cast a
|
||||||
-- Ref into Value. This behaves like a no-op.
|
-- Ref into Value. This behaves like a no-op.
|
||||||
value :: Tensor v a -> Tensor Value a
|
value :: Tensor v a -> Tensor Value a
|
||||||
|
|
Loading…
Reference in a new issue