| # RUN: %PYTHON %s | FileCheck %s |
| |
| from mlir import ir |
| from mlir.dialects import transform |
| from mlir.dialects.transform import tune, debug |
| |
| |
def run(f):
    """Run test builder ``f`` inside a fresh transform sequence and print the IR.

    A banner containing ``f.__name__`` is printed first. A new module is then
    created under a fresh context with an unknown location, a
    ``transform.SequenceOp`` is inserted into the module body, and ``f`` is
    called with the sequence's ``bodyTarget`` while the insertion point is the
    sequence body. The body is terminated with a ``transform.YieldOp`` and the
    module is printed so that FileCheck can match it.

    Returns ``f`` unchanged, which lets ``run`` be applied as a decorator.
    """
    print("\n// TEST:", f.__name__)
    with ir.Context(), ir.Location.unknown():
        test_module = ir.Module.create()

        # The sequence is the only op placed directly in the module body.
        with ir.InsertionPoint(test_module.body):
            root_sequence = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )

        # Everything the test builds lands in the sequence body, followed by
        # the terminator.
        with ir.InsertionPoint(root_sequence.body):
            f(root_sequence.bodyTarget)
            transform.YieldOp()

        print(test_module)
    return f
| |
| |
# CHECK-LABEL: TEST: testKnobOp
@run
def testKnobOp(target):
    """Build ``transform.tune.knob`` ops with and without a selected option.

    Both the ``tune.KnobOp`` class and the ``tune.knob`` builder are used, with
    a mix of positional and keyword arguments, over boolean, string/unit,
    integer and float option lists, and finally with a dictionary attribute as
    ``options``. ``target`` is not used by this test.
    """
    param_type = transform.AnyParamType.get()

    # Option lists shared by the unselected and the selected knobs below.
    coin_sides = [True, False]
    animals = ["cat", "dog", ir.UnitAttr.get()]
    tile_sizes = [2, 4, 8, 16, 24, 32]
    magic_values = [2.0, 2.25, 2.5, 2.75, 3.0]

    # Knobs without a selected value.
    # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
    coin_flip = tune.KnobOp(
        result=param_type, name=ir.StringAttr.get("coin"), options=coin_sides
    )
    # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
    tune.KnobOp(param_type, name="animal", options=animals)
    # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
    tune.KnobOp(param_type, "tile_size", tile_sizes)
    # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
    tune.knob(param_type, "magic_value", magic_values)

    # CHECK: transform.debug.emit_param_as_remark %[[HEADS_OR_TAILS]]
    debug.emit_param_as_remark(coin_flip)

    # The same knobs again, each with one option selected.
    # CHECK: %[[HEADS:.*]] = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param
    coin_heads = tune.KnobOp(param_type, "coin", options=coin_sides, selected=True)
    # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
    tune.KnobOp(
        param_type,
        name="animal",
        options=animals,
        selected="dog",
    )
    # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
    tune.KnobOp(param_type, "tile_size", tile_sizes, selected=8)
    # CHECK: transform.tune.knob<"magic_value"> = 2.500000e+00 : f64 from options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
    tune.knob(param_type, "magic_value", magic_values, selected=2.5)

    # CHECK: transform.debug.emit_param_as_remark %[[HEADS]]
    debug.emit_param_as_remark(coin_heads)

    # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param
    # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified.
    i64 = ir.IntegerType.get_signless(64)
    range_bounds = {"start": 2, "stop": 16, "step": 2}
    range_as_dict_attr = ir.DictAttr.get(
        {key: ir.IntegerAttr.get(i64, bound) for key, bound in range_bounds.items()}
    )
    tune.knob(
        param_type,
        "range_as_a_dict",
        range_as_dict_attr,
        selected=4,
    )
| |
| |
# CHECK-LABEL: TEST: testAlternativesOp
@run
def testAlternativesOp(target):
    """Build ``transform.tune.alternatives`` ops in three selection modes.

    Three two-region ops are created and their regions populated:

    * ``left_or_right``: no ``selected_region``; each region yields a param
      constant and the op produces one ``!transform.any_param`` result.
    * ``fork_in_the_road``: ``selected_region`` given as the integer ``0``.
    * ``left_or_right_as_before``: no results; ``selected_region`` is the
      result of ``left_or_right``, and each region emits a remark and yields
      nothing.

    ``target`` is not used by this test.
    """
    # Built once and reused for every result/constant type below instead of
    # being re-created at each use.
    any_param = transform.AnyParamType.get()
    i32 = ir.IntegerType.get_signless(32)

    # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param {
    left_or_right = tune.AlternativesOp([any_param], "left_or_right", 2)
    idx_for_left, idx_for_right = 0, 1
    with ir.InsertionPoint(left_or_right.alternatives[idx_for_left].blocks[0]):
        # CHECK: %[[C0:.*]] = transform.param.constant 0
        i32_0 = ir.IntegerAttr.get(i32, 0)
        c0 = transform.ParamConstantOp(any_param, i32_0)
        # CHECK: transform.yield %[[C0]]
        transform.yield_(c0)
    # CHECK-NEXT: }, {
    with ir.InsertionPoint(left_or_right.alternatives[idx_for_right].blocks[0]):
        # CHECK: %[[C1:.*]] = transform.param.constant 1
        i32_1 = ir.IntegerAttr.get(i32, 1)
        c1 = transform.ParamConstantOp(any_param, i32_1)
        # CHECK: transform.yield %[[C1]]
        transform.yield_(c1)
    # CHECK-NEXT: }
    outcome_of_left_or_right_decision = left_or_right.results[0]

    # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param {
    fork_in_the_road = tune.AlternativesOp(
        [any_param], "fork_in_the_road", 2, selected_region=0
    )
    with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_left].blocks[0]):
        # CHECK: %[[C0:.*]] = transform.param.constant 0
        i32_0 = ir.IntegerAttr.get(i32, 0)
        c0 = transform.ParamConstantOp(any_param, i32_0)
        # CHECK: transform.yield %[[C0]]
        transform.yield_(c0)
    # CHECK-NEXT: }, {
    with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_right].blocks[0]):
        # CHECK: %[[C1:.*]] = transform.param.constant 1
        i32_1 = ir.IntegerAttr.get(i32, 1)
        c1 = transform.ParamConstantOp(any_param, i32_1)
        # CHECK: transform.yield %[[C1]]
        transform.yield_(c1)
    # CHECK-NEXT: }

    # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param {
    left_or_right_as_before = tune.AlternativesOp(
        [],
        "left_or_right_as_before",
        2,
        selected_region=outcome_of_left_or_right_decision,
    )
    with ir.InsertionPoint(
        left_or_right_as_before.alternatives[idx_for_left].blocks[0]
    ):
        # CHECK: transform.param.constant 1337
        i32_1337 = ir.IntegerAttr.get(i32, 1337)
        c1337 = transform.ParamConstantOp(any_param, i32_1337)
        # CHECK: transform.debug.emit_param_as_remark
        debug.emit_param_as_remark(c1337)
        transform.yield_([])
    # CHECK-NEXT: }, {
    with ir.InsertionPoint(
        left_or_right_as_before.alternatives[idx_for_right].blocks[0]
    ):
        # CHECK: transform.param.constant 42
        i32_42 = ir.IntegerAttr.get(i32, 42)
        c42 = transform.ParamConstantOp(any_param, i32_42)
        # CHECK: transform.debug.emit_param_as_remark
        debug.emit_param_as_remark(c42)
        transform.yield_([])
    # CHECK-NEXT: }