mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
Add ToTensor class
This commit is contained in:
parent
d98e5d637c
commit
ddb4fe4f90
2 changed files with 14 additions and 2 deletions
|
@ -30,7 +30,7 @@ import TensorFlow.Core
|
||||||
import TensorFlow.Build (opDef)
|
import TensorFlow.Build (opDef)
|
||||||
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
import TensorFlow.BuildOp (buildInputs, pureOp, OpParams)
|
||||||
import TensorFlow.Output (opInputs, unNodeName)
|
import TensorFlow.Output (opInputs, unNodeName)
|
||||||
import TensorFlow.Tensor (Rendered(..), tensorNodeName)
|
import TensorFlow.Tensor (Rendered(..), ToTensor(..), tensorNodeName)
|
||||||
import TensorFlow.Types (tensorType)
|
import TensorFlow.Types (tensorType)
|
||||||
import qualified TensorFlow.GenOps.Core as CoreOps
|
import qualified TensorFlow.GenOps.Core as CoreOps
|
||||||
import TensorFlow.Ops (zeros)
|
import TensorFlow.Ops (zeros)
|
||||||
|
@ -40,6 +40,9 @@ newtype Variable a = Variable (Tensor Value ResourceHandle)
|
||||||
instance Rendered Variable where
|
instance Rendered Variable where
|
||||||
renderedOutput (Variable v) = renderedOutput v
|
renderedOutput (Variable v) = renderedOutput v
|
||||||
|
|
||||||
|
instance ToTensor Variable where
|
||||||
|
toTensor = readValue
|
||||||
|
|
||||||
-- | Creates a new, uninitialized variable.
|
-- | Creates a new, uninitialized variable.
|
||||||
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
variable :: (MonadBuild m, TensorType a) => Shape -> m (Variable a)
|
||||||
variable = variable' id
|
variable = variable' id
|
||||||
|
|
|
@ -38,7 +38,8 @@ import Proto.Tensorflow.Core.Framework.NodeDef (device)
|
||||||
import TensorFlow.Build
|
import TensorFlow.Build
|
||||||
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
|
import TensorFlow.Output (Output, NodeName, outputNodeName, Device(..))
|
||||||
import TensorFlow.Types
|
import TensorFlow.Types
|
||||||
( TensorData(..)
|
( TensorType
|
||||||
|
, TensorData(..)
|
||||||
, ListOf(..)
|
, ListOf(..)
|
||||||
)
|
)
|
||||||
import qualified TensorFlow.Internal.FFI as FFI
|
import qualified TensorFlow.Internal.FFI as FFI
|
||||||
|
@ -189,3 +190,11 @@ instance TensorKind Ref where
|
||||||
|
|
||||||
instance TensorKind Build where
|
instance TensorKind Build where
|
||||||
toBuild = id
|
toBuild = id
|
||||||
|
|
||||||
|
|
||||||
|
-- | Types which can be converted to `Tensor`.
|
||||||
|
class ToTensor t where
|
||||||
|
toTensor :: TensorType a => t a -> Tensor Build a
|
||||||
|
|
||||||
|
instance TensorKind v => ToTensor (Tensor v) where
|
||||||
|
toTensor = expr
|
||||||
|
|
Loading…
Reference in a new issue