-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- We use UndecidableInstances for type families with recursive definitions
-- like "\\".  Those instances will terminate since each equation unwraps one
-- cons cell of a type-level list.
{-# LANGUAGE UndecidableInstances #-}

module TensorFlow.Types
    ( TensorType(..)
    , TensorData(..)
    , Shape(..)
    , protoShape
    , Attribute(..)
    -- * Type constraints
    , OneOf
    , type (/=)
    -- ** Implementation of constraints
    , TypeError
    , ExcludedCase
    , TensorTypes
    , NoneOf
    , type (\\)
    , Delete
    , AllTensorTypes
    ) where

import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~))
import Lens.Family2.Unchecked (iso)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
    ( AttrValue(..)
    , AttrValue'ListValue(..)
    , b
    , f
    , i
    , s
    , list
    , type'
    , shape
    , tensor
    )
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
    ( TensorProto(..)
    , floatVal
    , doubleVal
    , intVal
    , stringVal
    , int64Val
    , stringVal
    , boolVal
    )
import Proto.Tensorflow.Core.Framework.TensorShape
    ( TensorShapeProto(..)
    , dim
    , size
    )
import Proto.Tensorflow.Core.Framework.Types (DataType(..))

import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI

-- | Data about a tensor that is encoded for the TensorFlow APIs.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }

-- | The class of scalar types supported by tensorflow.
class TensorType a where
    tensorType :: a -> DataType
    tensorRefType :: a -> DataType
    tensorVal :: Lens' TensorProto [a]
    -- | Decode the bytes of a TensorData into a Vector.
    decodeTensorData :: TensorData a -> V.Vector a
    -- | Encode a Vector into a TensorData.
    --
    -- The values should be in row major order, e.g.,
    --
    --   element 0:   index (0, ..., 0)
    --   element 1:   index (0, ..., 1)
    --   ...
    encodeTensorData :: Shape -> V.Vector a -> TensorData a

-- All types, besides ByteString, are encoded as simple arrays and we can use
-- Vector.Storable to encode/decode by type casting pointers.

-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> V.Vector a
simpleDecode = S.convert . S.unsafeCast . FFI.tensorDataBytes . unTensorData

simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> V.Vector a -> TensorData a
simpleEncode (Shape xs)
    = TensorData . FFI.TensorData xs dt . S.unsafeCast . S.convert
  where
    dt = tensorType (undefined :: a)

instance TensorType Float where
    tensorType _ = DT_FLOAT
    tensorRefType _ = DT_FLOAT_REF
    tensorVal = floatVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType Double where
    tensorType _ = DT_DOUBLE
    tensorRefType _ = DT_DOUBLE_REF
    tensorVal = doubleVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType Int32 where
    tensorType _ = DT_INT32
    tensorRefType _ = DT_INT32_REF
    tensorVal = intVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType Int64 where
    tensorType _ = DT_INT64
    tensorRefType _ = DT_INT64_REF
    tensorVal = int64Val
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

integral :: Integral a => Lens' [Int32] [a]
integral = iso (fmap fromIntegral) (fmap fromIntegral)

instance TensorType Word8 where
    tensorType _ = DT_UINT8
    tensorRefType _ = DT_UINT8_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType Word16 where
    tensorType _ = DT_UINT16
    tensorRefType _ = DT_UINT16_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType Int16 where
    tensorType _ = DT_INT16
    tensorRefType _ = DT_INT16_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType Int8 where
    tensorType _ = DT_INT8
    tensorRefType _ = DT_INT8_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType ByteString where
    tensorType _ = DT_STRING
    tensorRefType _ = DT_STRING_REF
    tensorVal = stringVal
    -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
    --   table offsets for each element :: [Word64]
    --   at each element offset:
    --     string length :: VarInt64
    --     string data   :: [Word8]
    -- TODO(fmayle): Benchmark these functions.
    decodeTensorData tensorData =
        either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
            if expected /= count
                then Left $ "decodeTensorData for ByteString count mismatch " ++
                            show (expected, count)
                else V.mapM decodeString (S.convert offsets)
      where
        expected = S.length offsets
        count = fromIntegral $ product $ FFI.tensorDataDimensions
                    $ unTensorData tensorData
        bytes = FFI.tensorDataBytes $ unTensorData tensorData
        offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
        dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
        decodeString :: Word64 -> Either String ByteString
        decodeString offset =
            let stringDataStart = B.drop (fromIntegral offset) dataBytes
            in Atto.eitherResult $ Atto.parse stringParser stringDataStart
        stringParser :: Atto.Parser ByteString
        stringParser = getVarInt >>= Atto.take . fromIntegral
    encodeTensorData (Shape xs) vec =
        TensorData $ FFI.TensorData xs dt byteVector
      where
        dt = tensorType (undefined :: ByteString)
        -- Add a string to an offset table and data blob.
        addString :: (Builder, Builder, Word64)
                  -> ByteString
                  -> (Builder, Builder, Word64)
        addString (table, strings, offset) str =
            ( table <> Builder.word64LE offset
            , strings <> lengthBytes <> Builder.byteString str
            , offset + lengthBytesLen + strLen
            )
          where
            strLen = fromIntegral $ B.length str
            lengthBytes = putVarInt $ fromIntegral $ B.length str
            lengthBytesLen =
                fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
        -- Encode all strings.
        (table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
        -- Concat offset table with data.
        bytes = table' <> strings'
        -- Convert to Vector Word8.
        byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes


instance TensorType Bool where
    tensorType _ = DT_BOOL
    tensorRefType _ = DT_BOOL_REF
    tensorVal = boolVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode

instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    tensorRefType _ = DT_COMPLEX64
    tensorVal = error "TODO (Complex Float)"
    decodeTensorData = error "TODO (Complex Float)"
    encodeTensorData = error "TODO (Complex Float)"

instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    tensorRefType _ = DT_COMPLEX128
    tensorVal = error "TODO (Complex Double)"
    decodeTensorData = error "TODO (Complex Double)"
    encodeTensorData = error "TODO (Complex Double)"

-- | Shape (dimensions) of a tensor.
newtype Shape = Shape [Int64] deriving Show

instance IsList Shape where
    type Item Shape = Int64
    fromList = Shape . fromList
    toList (Shape ss) = toList ss

protoShape :: Lens' TensorShapeProto Shape
protoShape = iso protoToShape shapeToProto
  where
    protoToShape = Shape . fmap (view size) . view dim
    shapeToProto (Shape ds) = def & dim .~ fmap (\d -> def & size .~ d) ds


class Attribute a where
    attrLens :: Lens' AttrValue a

instance Attribute Float where
    attrLens = f

instance Attribute ByteString where
    attrLens = s

instance Attribute Int64 where
    attrLens = i

instance Attribute DataType where
    attrLens = type'

instance Attribute TensorProto where
    attrLens = tensor

instance Attribute Bool where
    attrLens = b

instance Attribute Shape where
    attrLens = shape . protoShape

-- TODO(gnezdo): support generating list(Foo) from [Foo].
instance Attribute AttrValue'ListValue where
    attrLens = list

instance Attribute [DataType] where
    attrLens = list . type'

instance Attribute [Int64] where
    attrLens = list . i

-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
-- natural representation as a conjunction, i.e.,
--
-- @
--    a == Double || a == Float
-- @
--
-- into a disjunction like
--
-- @
--     a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
-- @
--
-- using an enumeration of all the possible 'TensorType's.
type OneOf ts a
    = (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)

-- | A 'Constraint' checking that the input is a list of 'TensorType's.
-- Helps improve error messages when using 'OneOf'.
type family TensorTypes ts :: Constraint where
    TensorTypes '[] = ()
    TensorTypes (t ': ts) = (TensorType t, TensorTypes ts)

-- | A constraint checking that two types are different.
type family a /= b :: Constraint where
    a /= a = TypeError a ~ ExcludedCase
    a /= b = ()

-- | Helper types to produce a reasonable type error message when the Constraint
-- "a /= a" fails.
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
data TypeError a
data ExcludedCase

-- | An enumeration of all valid 'TensorType's.
type AllTensorTypes =
    -- NOTE: This list should be kept in sync with
    -- TensorFlow.OpGen.dtTypeToHaskell.
    -- TODO: Add support for Complex Float/Double.
    '[ Float
     , Double
     , Int8
     , Int16
     , Int32
     , Int64
     , Word8
     , Word16
     , ByteString
     , Bool
     ]

-- | Removes a type from the given list of types.
type family Delete a as where
    Delete a '[] = '[]
    Delete a (a ': as) = Delete a as
    Delete a (b ': as) = b ': Delete a as

-- | Takes the difference of two lists of types.
type family as \\ bs where
    as \\ '[] = as
    as \\ b ': bs = Delete b as \\ bs

-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
type family NoneOf ts a :: Constraint where
    NoneOf '[] a = ()
    NoneOf (t ': ts) a = (a /= t, NoneOf ts a)