-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedStrings #-}

-- | Wrapping of TensorFlow attributes into Haskell entities.
module TensorFlow.OpGen.AttrVal
       (AttrDef
       , AttrCase(..)
       , AttrTemplate(..)
       , Template
       , attrDef
       , attrOriginal
       , attrTemplate
       , templateDefault
       , templateRestrictions
       ) where

import Data.Int (Int64)
import Data.Monoid ((<>))
import Lens.Family2 (Lens', (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue as AttrValue
import Proto.Tensorflow.Core.Framework.OpDef as OpDef
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Framework.TensorShape (TensorShapeProto)
import qualified Data.ByteString as B
import qualified Data.Text as Text

-- | Describes a single-valued attribute of type @a@: an optional default
-- value together with the set of values the attribute is allowed to take.
data Template a = Template {
    _templateDefault      :: Maybe a
    -- ^ The default value, if one is given. When no default is specified
    -- the attribute is mandatory.
  , _templateRestrictions :: [a]
    -- ^ The allowed set of values; an empty list means there are no
    -- restrictions.
 }

-- | Lens focusing on the optional default value stored in a 'Template'.
templateDefault :: Lens' (Template a) (Maybe a)
templateDefault = lens _templateDefault setDefault
  where
    -- Replace only the default, leaving the restrictions untouched.
    setDefault tpl newDefault = tpl { _templateDefault = newDefault }

-- | Lens focusing on the list of allowed values stored in a 'Template'.
templateRestrictions :: Lens' (Template a) [a]
templateRestrictions = lens _templateRestrictions setRestrictions
  where
    -- Replace only the restrictions, leaving the default untouched.
    setRestrictions tpl newAllowed = tpl { _templateRestrictions = newAllowed }

-- | Placeholder payload for attribute kinds that are not implemented in
-- this module. The type has no constructors, so no defined value of it
-- can be built; 'translate' fills it with an 'error' thunk.
data UnusedTensor

-- | The scalar attribute kinds handled by this module, one constructor
-- per kind. The parameter @f@ is the container wrapped around the
-- payload: 'Template' for single-valued attributes and @[]@ for
-- list-valued ones (see 'AttrTemplate').
--
-- The trailing comment on each constructor quotes the corresponding
-- field of the AttrValue proto and its TensorFlow type string.
data AttrCase f
  = AttrBytes (f B.ByteString)          -- bytes s = 2; // "string"
  | AttrInt64 (f Int64)                 -- int64 i = 3; // "int"
  | AttrFloat (f Float)                 -- float f = 4; // "float"
  | AttrBool  (f Bool)                  -- bool b = 5;  // "bool"
  | AttrType  (f DataType)              -- type = 6; // "type"
    -- The shape payload is meant to be translated into
    -- TensorFlow.Types.Shape before use; it is kept as a proto here
    -- to reduce dependencies.
  | AttrShape (f TensorShapeProto)      -- shape = 7; // "shape"

-- | Type-reified representation of a TensorFlow attribute definition.
-- Initially limited to just the types that appear in Op descriptors.
--
-- * 'AttrSingle' carries a single-valued attribute described by a
--   'Template' (optional default plus allowed values).
-- * 'AttrList' carries a list-valued attribute as a plain list.
-- * 'AttrTensor' stands in for kinds that are not implemented; its
--   payload type 'UnusedTensor' has no constructors.
data AttrTemplate
  = AttrSingle (AttrCase Template)
  | AttrList (AttrCase [])
  | AttrTensor UnusedTensor         -- tensor = 8; // "tensor"

-- | An attribute definition: the original proto paired with the
-- 'AttrTemplate' derived from it. Values are built with 'attrDef'.
data AttrDef = AttrDef {
    _attrOriginal :: OpDef'AttrDef -- ^ the proto this value was created from
  , _attrTemplate :: AttrTemplate  -- ^ the type of the attribute
  }

-- | Lens focusing on the reified attribute type of an 'AttrDef'.
attrTemplate :: Lens' AttrDef AttrTemplate
attrTemplate = lens _attrTemplate setTemplate
  where
    -- Swap the template while keeping the original proto.
    setTemplate def newTemplate = def { _attrTemplate = newTemplate }

-- | Lens focusing on the proto an 'AttrDef' was created from.
attrOriginal :: Lens' AttrDef OpDef'AttrDef
attrOriginal = lens _attrOriginal setOriginal
  where
    -- Swap the proto while keeping the derived template.
    setOriginal def newProto = def { _attrOriginal = newProto }

-- | Wraps an attribute proto into an 'AttrDef', deriving its
-- 'AttrTemplate' from the proto's type string, default value and
-- allowed values. The template is computed lazily, so an unknown type
-- string only raises an error once the template is forced.
attrDef :: OpDef'AttrDef -> AttrDef
attrDef proto = AttrDef
    { _attrOriginal = proto
    , _attrTemplate = template
    }
  where
    template = translate (proto ^. OpDef.type')
                         (proto ^. OpDef.defaultValue)
                         (proto ^. allowedValues)

-- | Builds the 'AttrTemplate' for an attribute from its TensorFlow type
-- string.
--
-- * Scalar types become an 'AttrSingle' whose 'Template' takes its
--   default from the default value and its restrictions from the list
--   field of the allowed values.
-- * @list(...)@ types become an 'AttrList' holding the list stored in
--   the default value; the allowed values are not consulted.
-- * @tensor@, @list(tensor)@ and @func@ become an 'AttrTensor' whose
--   payload is an 'error' thunk that throws when forced.
-- * Any other type string is an 'error' that reports the type string
--   together with both attribute values.
translate :: Text.Text  -- ^ one of the TensorFlow type strings
          -> AttrValue  -- ^ default value
          -> AttrValue  -- ^ allowed values
          -> AttrTemplate
translate typeName defaults allowed = case typeName of
    "string"       -> single AttrBytes maybe's s
    "int"          -> single AttrInt64 maybe'i i
    "float"        -> single AttrFloat maybe'f f
    "bool"         -> single AttrBool maybe'b b
    "type"         -> single AttrType AttrValue.maybe'type' AttrValue.type'
    "shape"        -> single AttrShape maybe'shape shape
    "tensor"       -> AttrTensor (error "tensor is unimplemented")
    "list(string)" -> multiple AttrBytes (list . s)
    "list(int)"    -> multiple AttrInt64 (list . i)
    "list(float)"  -> multiple AttrFloat (list . f)
    "list(bool)"   -> multiple AttrBool (list . b)
    "list(type)"   -> multiple AttrType (list . AttrValue.type')
    "list(shape)"  -> multiple AttrShape (list . shape)
    "list(tensor)" -> AttrTensor (error "list(tensor) is unimplemented")
    "func"         -> AttrTensor (error "func is unimplemented")
    _              -> error $ concat
        [ show ("Unknown attribute type " <> typeName)
        , ","
        , show defaults
        , ","
        , show allowed
        ]
  where
    -- Single-valued attribute: the default is read through defaultLens,
    -- the allowed values through allowedLens applied to the list field.
    single wrap defaultLens allowedLens =
        AttrSingle (wrap (Template (defaults ^. defaultLens)
                                   (allowed ^. list . allowedLens)))
    -- List-valued attribute: only the default value is inspected.
    multiple wrap valuesLens = AttrList (wrap (defaults ^. valuesLens))