module TensorFlow.BuildOp
( OpResult
, BuildOp
, buildOp
, buildListOp
, eqLengthGuard
)
where
import Control.Monad (replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, runState, get, put)
import Data.Int (Int64)
import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
-- | Bookkeeping threaded through 'Result' while the outputs of an op are
-- being decoded.
data ResultState
    = ResultState
        !OutputIx
        -- Index passed to 'output' for the next tensor; 'tensorResult'
        -- increments it once per tensor it produces.
        [Int64]
        -- List lengths not yet consumed; the list instance of 'OpResult'
        -- removes the head each time it decodes a list.
    deriving Show
-- | Monad in which op results are decoded: a reader over the 'Op' whose
-- outputs are being decoded, layered on a strict 'State' that carries the
-- 'ResultState'.
type Result = ReaderT Op (State ResultState)
-- | Types that can be decoded as the result of an op.
class OpResult a where
    -- | Decode one value of type @a@ inside 'Result'. The instances in this
    -- module read the 'Op' from the reader environment and/or advance the
    -- 'ResultState'.
    toResult :: Result a
-- | A pair is decoded one component at a time, left to right, so the first
-- component consumes the decoding state before the second.
instance (OpResult a1, OpResult a2) => OpResult (a1, a2) where
    toResult =
        (,)
            <$> toResult
            <*> toResult
-- | A triple is decoded one component at a time, left to right.
instance (OpResult a1, OpResult a2, OpResult a3)
    => OpResult (a1, a2, a3) where
    toResult =
        (,,)
            <$> toResult
            <*> toResult
            <*> toResult
-- | A 4-tuple is decoded one component at a time, left to right.
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4)
    => OpResult (a1, a2, a3, a4) where
    toResult =
        (,,,) <$> toResult <*> toResult
              <*> toResult <*> toResult
-- | A 5-tuple is decoded one component at a time, left to right.
instance (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5)
    => OpResult (a1, a2, a3, a4, a5) where
    toResult =
        (,,,,) <$> toResult <*> toResult <*> toResult <*> toResult <*> toResult
-- | A 6-tuple is decoded one component at a time, left to right.
instance
    (OpResult a1, OpResult a2, OpResult a3, OpResult a4, OpResult a5, OpResult a6)
    => OpResult (a1, a2, a3, a4, a5, a6) where
    toResult =
        (,,,,,) <$> toResult <*> toResult <*> toResult
                <*> toResult <*> toResult <*> toResult
-- | Decode a single tensor of the given kind.
--
-- The tensor refers to the output of the current 'Op' at the index stored
-- in the 'ResultState'; that index is then advanced by one so the next
-- tensor decoded gets the following output.
tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult kind = do
    ResultState ix counts <- get
    -- Only the index moves; pending list lengths are left untouched.
    put $! ResultState (ix + 1) counts
    op <- ask
    return $! Tensor kind (output ix op)
-- | A value tensor is decoded via 'tensorResult' with 'ValueKind'.
instance OpResult (Tensor Value a) where
    toResult = tensorResult ValueKind
-- | A reference tensor is decoded via 'tensorResult' with 'RefKind'.
instance OpResult (Tensor Ref a) where
    toResult = tensorResult RefKind
-- | A control node just wraps the current 'Op'; it does not touch the
-- 'ResultState', so it consumes no output index.
instance OpResult ControlNode where
    toResult = fmap ControlNode ask
-- | A list result takes its length from the head of the pending counts in
-- the 'ResultState', then decodes that many elements in sequence.
--
-- Calls 'error' if no counts remain.
instance OpResult a => OpResult [a] where
    toResult = do
        ResultState ix counts <- get
        case counts of
            count : remaining -> do
                -- Drop the consumed length before decoding the elements so
                -- nested or subsequent lists see only their own counts.
                put $! ResultState ix remaining
                replicateM (fromIntegral count) toResult
            [] ->
                error $ "Ran out of counts in toResult. " ++
                        "Likely misuse of buildListOp."
-- | Decode the result of the given 'Op'.
--
-- Decoding starts at output index 0 with the supplied list lengths pending.
-- Every supplied length must be consumed by the time decoding finishes;
-- otherwise 'error' is called with the original lengths and the leftover
-- decoder state.
runResult :: OpResult a => [Int64] -> Op -> a
runResult ns o =
    case runState (runReaderT toResult o) (ResultState 0 ns) of
        (x, ResultState _ []) -> x
        -- 'leftover' is the whole final 'ResultState' (index and remaining
        -- counts), shown alongside the lengths the caller passed in.
        (_, leftover) -> error $ "Unused length in runResult attributes: " ++
                                 show (ns, leftover)
-- | Build a result without running any 'Build' action: the accumulated
-- inputs are attached to the 'OpDef', which is wrapped as an 'Unrendered'
-- op and decoded with 'runResult'.
pureResult :: OpResult a => [Int64] -> OpDef -> [Output] -> a
pureResult counts opDef inputs = runResult counts (Unrendered withInputs)
  where
    withInputs = addReversedInputs opDef inputs
-- | Build a result inside 'Build': the accumulated inputs are attached to
-- the 'OpDef', the op is passed to 'addNewOp', and what comes back is
-- wrapped as a 'Rendered' op and decoded with 'runResult'.
buildResult :: OpResult a => [Int64] -> OpDef -> [Output] -> Build a
buildResult counts opDef inputs = fmap decode (addNewOp withInputs)
  where
    withInputs = addReversedInputs opDef inputs
    decode = runResult counts . Rendered
-- | Append the given outputs, in reverse order, to the inputs of an
-- 'OpDef'. The 'BuildOp' function instances accumulate inputs with the
-- most recent one first, so reversing restores argument order.
addReversedInputs :: OpDef -> [Output] -> OpDef
addReversedInputs opDef accumulated = opDef & opInputs <>~ inArgumentOrder
  where
    inArgumentOrder = reverse accumulated
-- | Types that an op-building function may have: a result type, or a
-- function taking tensors (or lists of tensors) that ends in a result type.
class BuildOp f where
    -- | Worker behind 'buildOp' and 'buildListOp'.
    buildOp'
        :: [Int64]   -- ^ Lengths for list-valued results.
        -> OpDef     -- ^ The op being built.
        -> [Output]  -- ^ Inputs accumulated so far, most recent first.
        -> f
-- | Start building an op from its 'OpDef', with no list lengths and no
-- inputs accumulated yet.
buildOp :: BuildOp f => OpDef -> f
buildOp opDef = buildOp' noCounts opDef noInputs
  where
    noCounts = []
    noInputs = []
-- | Like 'buildOp', but also supplies the lengths used when decoding
-- list-valued results. Lengths are consumed from the front of the list,
-- one per list decoded.
buildListOp
    :: BuildOp f
    => [Int64]  -- ^ Lengths of the list-valued results.
    -> OpDef
    -> f
buildListOp listLengths opDef = buildOp' listLengths opDef noInputs
  where
    noInputs = []
-- | A control node is built directly from the 'Unrendered' op with its
-- accumulated inputs attached. The list lengths are ignored.
instance BuildOp ControlNode where
    buildOp' _ opDef inputs = ControlNode (Unrendered withInputs)
      where
        withInputs = addReversedInputs opDef inputs
-- | A single value tensor is produced purely, via 'pureResult'.
instance BuildOp (Tensor Value a) where
    buildOp' = pureResult
-- | A single reference tensor is produced purely, via 'pureResult'.
instance BuildOp (Tensor Ref a) where
    buildOp' = pureResult
-- | A list of value tensors is produced purely, via 'pureResult'; its
-- length comes from the list lengths given to 'buildOp''.
instance BuildOp [Tensor Value a] where
    buildOp' = pureResult
-- | A pair of results is produced purely, via 'pureResult'.
instance (OpResult t1, OpResult t2)
    => BuildOp (t1, t2) where
    buildOp' = pureResult
-- | A triple of results is produced purely, via 'pureResult'.
instance (OpResult t1, OpResult t2, OpResult t3)
    => BuildOp (t1, t2, t3) where
    buildOp' = pureResult
-- | A 4-tuple of results is produced purely, via 'pureResult'.
instance
    (OpResult t1, OpResult t2, OpResult t3, OpResult t4)
    => BuildOp (t1, t2, t3, t4) where
    buildOp' = pureResult
-- | A 5-tuple of results is produced purely, via 'pureResult'.
instance
    (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5)
    => BuildOp (t1, t2, t3, t4, t5) where
    buildOp' = pureResult
-- | A 6-tuple of results is produced purely, via 'pureResult'.
instance
    (OpResult t1, OpResult t2, OpResult t3, OpResult t4, OpResult t5, OpResult t6)
    => BuildOp (t1, t2, t3, t4, t5, t6) where
    buildOp' = pureResult
-- | A result wrapped in 'Build' is produced via 'buildResult', which goes
-- through 'addNewOp' instead of leaving the op unrendered.
instance OpResult a => BuildOp (Build a) where
    buildOp' = buildResult
-- | A tensor argument: its output is pushed onto the front of the
-- accumulated inputs and building continues with the remaining type.
instance BuildOp f => BuildOp (Tensor v a -> f) where
    buildOp' counts opDef accumulated tensor =
        buildOp' counts opDef (newInput : accumulated)
      where
        newInput = tensor ^. tensorOutput
-- | A list-of-tensors argument: the outputs of all tensors are pushed onto
-- the accumulated inputs, and building continues with the remaining type.
instance BuildOp f => BuildOp ([Tensor v a] -> f) where
    buildOp' counts opDef accumulated tensors =
        buildOp' counts opDef (newestFirst ++ accumulated)
      where
        -- The accumulator holds the most recent input first, so the list is
        -- reversed before being placed in front of it.
        newestFirst = reverse [t ^. tensorOutput | t <- tensors]
-- | Check that, within each group, every entry carries the same 'Int'.
--
-- Each group pairs a @number_attr@ name with a list of @(label, length)@
-- entries. The result is 'True' when every group is empty or has all of its
-- lengths equal to the first one. On the first group with a mismatch,
-- 'error' is called with a message naming the attribute and showing the
-- group's entries; the function never returns 'False'.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all (uncurry lengthsAgree)
  where
    lengthsAgree _ [] = True
    lengthsAgree attrName entries@((_, expected) : others)
        | all ((== expected) . snd) others = True
        | otherwise =
            error ("number_attr " ++ attrName ++
                   " contains tensors with different length " ++ show entries)