[MLIR] Fix shape inference in toy tutorial

The implementation of shape inference in the toy tutorial did not conform to the correct algorithmic description.
The result was only correct because all operations appear to be processed in sequence.

Differential Revision: https://reviews.llvm.org/D77382
diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index 107c804..296bec0 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index 107c804..296bec0 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
index 107c804..296bec0 100644
--- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
index 107c804..296bec0 100644
--- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -62,7 +62,7 @@
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
+      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
       if (nextop == opWorklist.end())
         break;
 
@@ -88,6 +88,14 @@
     }
   }
 
+  /// A utility method that returns if the given operation has all of its
+  /// operands inferred.
+  static bool allOperandsInferred(Operation *op) {
+    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
+      return operandType.isa<RankedTensorType>();
+    });
+  }
+
   /// A utility method that returns if the given operation has a dynamically
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {