module TensorFlow.OpGen.AttrVal
(AttrDef
, AttrCase(..)
, AttrTemplate(..)
, Template
, attrDef
, attrOriginal
, attrTemplate
, templateDefault
, templateRestrictions
) where
import Data.Int (Int64)
import Data.Monoid ((<>))
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
import qualified Data.ByteString as B
import qualified Data.Text as Text
-- | The default value and the permitted values of a single-valued attribute.
data Template a = Template
    { _templateDefault :: Maybe a
      -- ^ The attribute's default value, when one is present.
    , _templateRestrictions :: [a]
      -- ^ Values drawn from the attribute's allowed-values list.
    }
-- | Lens focusing on the optional default value stored in a 'Template'.
templateDefault :: Lens' (Template a) (Maybe a)
templateDefault = lens getter setter
  where
    getter = _templateDefault
    setter tmpl newDefault = tmpl { _templateDefault = newDefault }
-- | Lens focusing on the list of permitted values stored in a 'Template'.
templateRestrictions :: Lens' (Template a) [a]
templateRestrictions = lens _templateRestrictions replaceRestrictions
  where
    replaceRestrictions tmpl vals = tmpl { _templateRestrictions = vals }
-- | Uninhabited placeholder payload for attribute kinds this module does not
-- implement. It has no constructors; the only values of this type built in
-- this module are 'error' thunks created by 'translate'.
data UnusedTensor
-- | The payload of one attribute, tagged by the kind of value it holds.
-- The container @f@ is chosen by the user of this type: 'AttrTemplate'
-- instantiates it with 'Template' for single-valued attributes and with
-- @[]@ for list attributes.
data AttrCase f
    = AttrBytes (f B.ByteString)      -- ^ @"string"@ / @"list(string)"@
    | AttrInt64 (f Int64)             -- ^ @"int"@ / @"list(int)"@
    | AttrFloat (f Float)             -- ^ @"float"@ / @"list(float)"@
    | AttrBool (f Bool)               -- ^ @"bool"@ / @"list(bool)"@
    | AttrType (f DataType)           -- ^ @"type"@ / @"list(type)"@
    | AttrShape (f TensorShapeProto)  -- ^ @"shape"@ / @"list(shape)"@
-- | How an attribute of a particular type is represented for code generation.
data AttrTemplate
    = AttrSingle (AttrCase Template)
      -- ^ A scalar attribute with its optional default and its allowed values.
    | AttrList (AttrCase [])
      -- ^ A @list(...)@ attribute with its default list of values.
    | AttrTensor UnusedTensor
      -- ^ Placeholder for unimplemented kinds (such as @"tensor"@ and
      -- @"func"@). 'translate' stores an 'error' thunk here, so forcing
      -- the payload fails.
-- | An op attribute definition paired with its decoded 'AttrTemplate'.
data AttrDef = AttrDef
    { _attrOriginal :: OpDef'AttrDef
      -- ^ The attribute definition as it was supplied to 'attrDef'.
    , _attrTemplate :: AttrTemplate
      -- ^ The template derived from that definition.
    }
-- | Lens focusing on the decoded 'AttrTemplate' of an 'AttrDef'.
attrTemplate :: Lens' AttrDef AttrTemplate
attrTemplate = lens getTemplate setTemplate
  where
    getTemplate = _attrTemplate
    setTemplate def' tmpl = def' { _attrTemplate = tmpl }
-- | Lens focusing on the original protobuf attribute definition of an 'AttrDef'.
attrOriginal :: Lens' AttrDef OpDef'AttrDef
attrOriginal = lens getOriginal setOriginal
  where
    getOriginal = _attrOriginal
    setOriginal def' orig = def' { _attrOriginal = orig }
-- | Wrap an op attribute definition together with the 'AttrTemplate'
-- obtained by translating its type name, default value and allowed values.
attrDef :: OpDef'AttrDef -> AttrDef
attrDef a = AttrDef { _attrOriginal = a, _attrTemplate = decoded }
  where
    decoded = translate (a^.OpDef.type')
                        (a^.OpDef.defaultValue)
                        (a^.allowedValues)
-- | Decode an attribute's type name, default value and allowed values into
-- an 'AttrTemplate'.
--
-- * @t@: the attribute type name, e.g. @"int"@ or @"list(type)"@.
-- * @defaults@: the attribute's default value.
-- * @allowed@: the attribute's allowed values. Only consulted for scalar
--   attributes, where it supplies the 'Template' restrictions.
--
-- Unimplemented kinds (@tensor@, @func@ and their list forms) yield an
-- 'AttrTensor' whose payload is an 'error' thunk, so they only fail if that
-- payload is forced. Any other unknown type name raises an 'error'.
translate :: Text.Text
          -> AttrValue
          -> AttrValue
          -> AttrTemplate
translate t defaults allowed
    | t == "string" = makeVal AttrBytes maybe's s
    | t == "int" = makeVal AttrInt64 maybe'i i
    | t == "float" = makeVal AttrFloat maybe'f f
    | t == "bool" = makeVal AttrBool maybe'b b
    | t == "type" = makeVal AttrType AttrValue.maybe'type' AttrValue.type'
    | t == "shape" = makeVal AttrShape maybe'shape shape
    | t == "tensor" = AttrTensor $ error "tensor is unimplemented"
    | t == "list(string)" = makeList AttrBytes $ list.s
    | t == "list(int)" = makeList AttrInt64 $ list.i
    | t == "list(float)" = makeList AttrFloat $ list.f
    | t == "list(bool)" = makeList AttrBool $ list.b
    | t == "list(type)" = makeList AttrType $ list.AttrValue.type'
    | t == "list(shape)" = makeList AttrShape $ list.shape
    | t == "list(tensor)" = AttrTensor $ error "list(tensor) is unimplemented"
    | t == "func" = AttrTensor $ error "func is unimplemented"
    -- "list(func)" is a valid attr type (AttrValue.ListValue has a func
    -- field). Previously it fell through to the fatal 'otherwise' branch.
    -- Treat it like its scalar sibling "func".
    | t == "list(func)" = AttrTensor $ error "list(func) is unimplemented"
    | otherwise = error $ show ("Unknown attribute type " <> t) ++
                          "," ++ show defaults ++
                          "," ++ show allowed
  where
    -- Scalar attribute: the optional default is read through the maybe'
    -- lens @x@, and the restrictions come from the allowed-values list via
    -- the field lens @y@.
    makeVal c x y = AttrSingle $ c $
        Template (defaults^.x) (allowed^.list.y)
    -- List attribute: only the default list is kept.
    makeList c y = AttrList $ c $ defaults^.y