diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index b5fd115..32246d4 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -262,7 +262,7 @@ instance TensorDataType S.Vector Bool where where fromBool x = if x then 1 else 0 :: Word8 -instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a) +instance {-# OVERLAPPABLE #-} (TensorType a, Storable a, TensorDataType S.Vector a) => TensorDataType V.Vector a where decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a) @@ -329,7 +329,7 @@ newtype Scalar a = Scalar {unScalar :: a} deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat, RealFrac, IsString) -instance TensorDataType V.Vector a => TensorDataType Scalar a where +instance (TensorType a, TensorDataType V.Vector a) => TensorDataType Scalar a where decodeTensorData = Scalar . headFromSingleton . decodeTensorData encodeTensorData x (Scalar y) = encodeTensorData x (V.fromList [y])