diff --git a/README.md b/README.md index f830c3f..c5a9650 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,37 @@ No exportable Nix package will appear, but local development is possible. # Related Projects +## Statically validated tensor shapes + https://github.com/helq/tensorflow-haskell-deptyped is experimenting with using dependent types to statically validate tensor shapes. May be merged with this repository in the future. +Example: + +```haskell +{-# LANGUAGE DataKinds, ScopedTypeVariables #-} + +import Data.Maybe (fromJust) +import Data.Vector.Sized (Vector, fromList) +import TensorFlow.DepTyped + +test :: IO (Vector 8 Float) +test = runSession $ do + (x :: Placeholder "x" '[4,3] Float) <- placeholder + + let elems1 = fromJust $ fromList [1,2,3,4,1,2] + elems2 = fromJust $ fromList [5,6,7,8] + (w :: Tensor '[3,2] '[] Build Float) = constant elems1 + (b :: Tensor '[4,1] '[] Build Float) = constant elems2 + y = (x `matMul` w) `add` b -- y shape: [4,2] (b shape is [4.1] but `add` broadcasts it to [4,2]) + + let (inputX :: TensorData "x" [4,3] Float) = + encodeTensorData . fromJust $ fromList [1,2,3,4,1,0,7,9,5,3,5,4] + + runWithFeeds (feed x inputX :~~ NilFeedList) y + +main :: IO () +main = test >>= print +``` + # License This project is licensed under the terms of the [Apache 2.0 license](LICENSE).