{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
module TensorFlow.Tensor where
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal')
import Lens.Family2.Unchecked (lens)
import TensorFlow.Output
( Output
, ResourceHandle
, outputOp
, opUnrendered
, opAttr
, resourceHandleOutput
import TensorFlow.Types (TensorData(..), Attribute)
import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation.
--
-- The type parameter @a@ is the type of the elements in the 'Tensor'. The
-- parameter @v@ is either 'Value' or 'Ref', depending on whether the graph is
-- treating this op output as an immutable 'Value' or a stateful 'Ref' (e.g., a
-- variable). Note that a @Tensor Ref@ can be cast into a @Tensor Value@ via
-- 'value'.
--
-- The constructor pairs the runtime 'TensorKind' tag for @v@ with the
-- underlying 'Output'; @a@ does not appear in any field.
data Tensor v a = Tensor (TensorKind v) Output
-- | Type-level tag for a 'Tensor' that the graph treats as an immutable
-- value. It has no constructors and is only used as a type index (see
-- 'TensorKind').
data Value
-- | Type-level tag for a 'Tensor' that the graph treats as a stateful
-- reference (e.g., a variable). It has no constructors and is only used as a
-- type index (see 'TensorKind').
data Ref
-- | This type provides a runtime switch on whether a 'Tensor' should be
-- treated as a 'Value' or as a 'Ref'.
--
-- Each constructor fixes the type index @v@, so pattern matching on a
-- 'TensorKind' reveals which of the two tags @v@ is.
data TensorKind v where
    -- | The tensor is treated as a 'Value'.
    ValueKind :: TensorKind Value
    -- | The tensor is treated as a 'Ref'.
    RefKind :: TensorKind Ref
-- | Lens onto the 'TensorKind' tag of a 'Tensor'. Setting the tag keeps the
-- wrapped 'Output' as it was.
tensorKind :: Lens' (Tensor v a) (TensorKind v)
tensorKind = lens getKind setKind
  where
    -- Read the kind tag, ignoring the output.
    getKind (Tensor kind _) = kind
    -- Swap in a new kind tag while carrying the output over.
    setKind (Tensor _ out) kind = Tensor kind out
-- | Lens onto the 'Output' wrapped by a 'Tensor'. Setting the output keeps
-- the 'TensorKind' tag as it was.
tensorOutput :: Lens' (Tensor v a) Output
tensorOutput =
    let viewOutput (Tensor _ out) = out
        replaceOutput (Tensor kind _) out = Tensor kind out
    in lens viewOutput replaceOutput
-- TODO: Come up with a better API for handling attributes.

-- | Traversal over an attribute of a tensor, identified by the given text
-- (which is handed to 'opAttr').
--
-- The path runs through 'opUnrendered', so it is only valid while the tensor
-- has not yet been rendered. Once the tensor has been rendered the traversal
-- has no targets: nothing can be read or written.
tensorAttr :: Attribute attr => Text.Text -> Traversal' (Tensor v a) attr
tensorAttr attrName =
    tensorOutput . outputOp . opUnrendered . opAttr attrName
-- | Traversal over an attribute reached from a 'ResourceHandle', going
-- through 'resourceHandleOutput', 'outputOp', 'opUnrendered' and 'opAttr'.
--
-- This is the 'ResourceHandle' counterpart of 'tensorAttr'. Because it also
-- passes through 'opUnrendered', it presumably has no targets once the
-- underlying op has been rendered (TODO confirm).
resourceHandleAttr
    :: Attribute attr
    => Text.Text
    -> Traversal' (ResourceHandle a) attr
resourceHandleAttr key =
    resourceHandleOutput . outputOp . opUnrendered . opAttr key
-- | Reinterpret any 'Tensor' as a @Tensor Value@. The common use is turning
-- a 'Ref' into a 'Value'.
--
-- Only the kind tag is replaced (with 'ValueKind'); the wrapped 'Output' is
-- carried over unchanged, so this behaves like a no-op.
value :: Tensor v a -> Tensor Value a
value tensor = case tensor of
    Tensor _ out -> Tensor ValueKind out
-- | A pair of a 'Tensor' and some data that should be fed into that 'Tensor'
-- when running the graph.
--
-- The tensor is stored as its underlying 'Output', and the data as an
-- 'FFI.TensorData' (which carries no element type parameter). Use 'feed' to
-- build one from a typed 'Tensor' and 'TensorData'.
data Feed = Feed Output FFI.TensorData
-- | Build a 'Feed' that supplies the given data to a 'Tensor' when the graph
-- is run.
--
-- The 'Feed' records the tensor's 'Output' together with the contents
-- unwrapped from the 'TensorData'; the shared type parameter @a@ ties the
-- data to the tensor in the signature.
--
-- Note that if a 'Tensor' is rendered, its identity may change; so feeding
-- the rendered 'Tensor' may be different than feeding the original 'Tensor'.
feed :: Tensor v a -> TensorData a -> Feed
feed tensor payload =
    case (tensor, payload) of
        (Tensor _ out, TensorData raw) -> Feed out raw
-- | Build a 'Tensor' of the given kind that refers to a node by name. This
-- can be used to reference nodes in a 'GraphDef' that was loaded via
-- 'addGraphDef'.
--
-- The name is unpacked to a 'String' and turned into an 'Output' with
-- 'fromString'; this function itself performs no checks beyond whatever that
-- conversion does.
--
-- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
tensorFromName kind name = Tensor kind (fromString (Text.unpack name))