mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Add ToTensor class
This commit is contained in:
parent
d98e5d637c
commit
ddb4fe4f90
2 changed files with 14 additions and 2 deletions
|
@ -30,7 +30,7 @@ import TensorFlow.Core
|
|||
import TensorFlow.Build (opDef)
|
||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||
import TensorFlow.Output (opInputs, unNodeName)
|
||||
import TensorFlow.Tensor (Rendered(..), tensorNodeName)
|
||||
import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName)
|
||||
import TensorFlow.Types (tensorType)
|
||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||
import TensorFlow.Ops (zeros)
|
||||
|
@ -40,6 +40,9 @@ newtype Variable a = Variable (Tensor Value ResourceHandle)
|
|||
instance Rendered Variable where
|
||||
renderedOutput (Variable v) = renderedOutput v
|
||||
|
||||
instance ToTensor Variable where
|
||||
toTensor = readValue
|
||||
|
||||
-- | Creates a new, uninitialized variable.
|
||||
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
||||
variable = variable' id
|
||||
|
|
|
@ -38,7 +38,8 @@ import Proto.Tensorflow.Core.Framework.NodeDef (device)
|
|||
import TensorFlow.Build
|
||||
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
|
||||
import TensorFlow.Types
|
||||
( TensorData(..)
|
||||
( TensorType
|
||||
, TensorData(..)
|
||||
, ListOf(..)
|
||||
)
|
||||
import qualified TensorFlow.Internal.FFI as FFI
|
||||
|
@ -189,3 +190,11 @@ instance TensorKind Ref where
|
|||
|
||||
instance TensorKind Build where
|
||||
toBuild = id
|
||||
|
||||
|
||||
-- | Types which can be converted to `Tensor`.
|
||||
class ToTensor t where
|
||||
toTensor :: TensorType a => t a -> Tensor Build a
|
||||
|
||||
instance TensorKind v => ToTensor (Tensor v) where
|
||||
toTensor = expr
|
||||
|
|
Loading…
Reference in a new issue