Add ToTensor class

This commit is contained in:
fkm3 2017-05-15 22:09:21 -07:00
parent d98e5d637c
commit ddb4fe4f90
2 changed files with 14 additions and 2 deletions

View File

@ -30,7 +30,7 @@ import TensorFlow.Core
import TensorFlow.Build (opDef)
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
import TensorFlow.Output (opInputs, unNodeName)
import TensorFlow.Tensor (Rendered(..), tensorNodeName)
import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName)
import TensorFlow.Types (tensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
import TensorFlow.Ops (zeros)
@ -40,6 +40,9 @@ newtype Variable a = Variable (Tensor Value ResourceHandle)
instance Rendered Variable where
renderedOutput (Variable v) = renderedOutput v
instance ToTensor Variable where
toTensor = readValue
-- | Creates a new, uninitialized variable.
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
variable = variable' id

View File

@ -38,7 +38,8 @@ import Proto.Tensorflow.Core.Framework.NodeDef (device)
import TensorFlow.Build
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
import TensorFlow.Types
( TensorData(..)
( TensorType
, TensorData(..)
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI
@ -189,3 +190,11 @@ instance TensorKind Ref where
instance TensorKind Build where
toBuild = id
-- | Types which can be converted to `Tensor`.
class ToTensor t where
toTensor :: TensorType a => t a -> Tensor Build a
instance TensorKind v => ToTensor (Tensor v) where
toTensor = expr