[mlir][spirv] Change numeric constant's bit-extension decision to be based on type

Integer constants with bit width less than a word (e.g., i8, i16)
should be bit extended based on its type to be SPIR-V spec-compliant.
Previously, the decision was based on the most significant bit of the
value which ignores the signless semantics and causes problems when
interfacing with SPIR-V tools.

Dealing with numeric literals: the SPIR-V spec says, "If a numeric
type’s bit width is less than 32-bits, the value appears in the
low-order bits of the word, and the high-order bits must be 0 for
a floating-point type or integer type with Signedness of 0, or sign
extended for an integer type with a Signedness of 1 (similarly for the
remaining bits of widths larger than 32 bits but not a multiple of 32
bits)."

Therefore, signless integers (e.g., i8, i16) and unsigned integers
should be 0-extended, and signed integers (e.g., si8, si16) should be
sign-extended.

Patch By: mshahneo
Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D151767
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index f32f6e8..1ef8ff0 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -846,8 +846,7 @@
   auto resultID = getNextID();
   APInt value = intAttr.getValue();
   unsigned bitwidth = value.getBitWidth();
-  bool isSigned = value.isSignedIntN(bitwidth);
-
+  bool isSigned = intAttr.getType().isSignedInteger();
   auto opcode =
       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
 
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 7e3ae2a..f395021 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -264,4 +264,17 @@
     %0 = spirv.Constant dense<1> : tensor<2x2x3xi32> : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32, stride=4>, stride=12>, stride=24>
     spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32, stride=4>, stride=12>, stride=24>
   }
+
+  // CHECK-LABEL: @signless_int_const_bit_extension
+  spirv.func @signless_int_const_bit_extension() -> (i16) "None" {
+    // CHECK: spirv.Constant -1 : i16
+    %signless_minus_one = spirv.Constant -1 : i16
+    spirv.ReturnValue %signless_minus_one : i16
+  }
+  // CHECK-LABEL: @signed_int_const_bit_extension
+  spirv.func @signed_int_const_bit_extension() -> (si16) "None" {
+    // CHECK: spirv.Constant -1 : si16
+    %signed_minus_one = spirv.Constant -1 : si16
+    spirv.ReturnValue %signed_minus_one : si16
+  }
 }
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index f7a1db0..56a98cc 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -76,6 +76,27 @@
         builder.getStringAttr(name), nullptr);
   }
 
+  // Inserts an Integer or a Vector of Integers constant of value 'val'.
+  spirv::ConstantOp AddConstInt(Type type, APInt val) {
+    OpBuilder builder(module->getRegion());
+    auto loc = UnknownLoc::get(&context);
+
+    if (auto intType = dyn_cast<IntegerType>(type)) {
+      return builder.create<spirv::ConstantOp>(
+          loc, type, builder.getIntegerAttr(type, val));
+    }
+    if (auto vectorType = dyn_cast<VectorType>(type)) {
+      Type elemType = vectorType.getElementType();
+      if (auto intType = dyn_cast<IntegerType>(elemType)) {
+        return builder.create<spirv::ConstantOp>(
+            loc, type,
+            DenseElementsAttr::get(vectorType,
+                                   IntegerAttr::get(elemType, val).getValue()));
+      }
+    }
+    llvm_unreachable("unimplemented types for AddConstInt()");
+  }
+
   /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
   /// Returns true to interrupt.
   using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
@@ -149,6 +170,34 @@
   EXPECT_EQ(count, 1u);
 }
 
+TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
+
+  auto signlessInt16Type =
+      IntegerType::get(&context, 16, IntegerType::Signless);
+  auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
+  // Check the bit extension of same value under different signedness semantics.
+  APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
+                            signlessInt16Type.getSignedness());
+  APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
+                          signedInt16Type.getSignedness());
+
+  AddConstInt(signlessInt16Type, signlessIntConstVal);
+  AddConstInt(signedInt16Type, signedIntConstVal);
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
+  auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
+           operands[2] == 65535;
+  };
+  EXPECT_TRUE(scanInstruction(hasSignlessVal));
+
+  auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
+           operands[2] == 4294967295;
+  };
+  EXPECT_TRUE(scanInstruction(hasSignedVal));
+}
+
 TEST_F(SerializationTest, ContainsSymbolName) {
   auto structType = getFloatStructType();
   addGlobalVar(structType, "var0");