From a02ecdee1578be16b94fcb75158a704fd16951a1 Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Thu, 6 Jun 2024 11:50:46 -0700 Subject: [PATCH] Adds support for parameterized decimal cast (#1483) --- .../partiql/planner/internal/PartiQLHeader.kt | 32 ++++++++ .../internal/transforms/RexConverter.kt | 30 +++++-- .../planner/internal/typer/FnResolver.kt | 33 +++++++- .../planner/internal/typer/PlanTyper.kt | 27 +++++++ .../planner/internal/typer/TypeLattice.kt | 2 + .../internal/typer/PlanTyperTestsPorted.kt | 81 +++++++++++++++++++ 6 files changed, 198 insertions(+), 7 deletions(-) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt index ab50c6ac09..2d7af3545b 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt @@ -9,9 +9,14 @@ import org.partiql.value.PartiQLValueType.BOOL import org.partiql.value.PartiQLValueType.CHAR import org.partiql.value.PartiQLValueType.DATE import org.partiql.value.PartiQLValueType.DECIMAL +import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY +import org.partiql.value.PartiQLValueType.FLOAT32 +import org.partiql.value.PartiQLValueType.FLOAT64 import org.partiql.value.PartiQLValueType.INT +import org.partiql.value.PartiQLValueType.INT16 import org.partiql.value.PartiQLValueType.INT32 import org.partiql.value.PartiQLValueType.INT64 +import org.partiql.value.PartiQLValueType.INT8 import org.partiql.value.PartiQLValueType.MISSING import org.partiql.value.PartiQLValueType.NULL import org.partiql.value.PartiQLValueType.STRING @@ -63,6 +68,7 @@ internal object PartiQLHeader : Header() { mod(), concat(), bitwiseAnd(), + castAsParameterizedDecimal(), // explicit casts (aka NOT coercions from TypeLattice). ).flatten() /** @@ -460,6 +466,32 @@ internal object PartiQLHeader : Header() { ) } + private fun castAsParameterizedDecimal(): List = listOf( + BOOL, + INT8, + INT16, + INT32, + INT64, + INT, + DECIMAL, + DECIMAL_ARBITRARY, + FLOAT32, + FLOAT64, + STRING, + ).map { value -> + FunctionSignature.Scalar( + name = "cast_decimal", + returns = DECIMAL, + parameters = listOf( + FunctionParameter("value", value), + FunctionParameter("precision", INT32), + FunctionParameter("scale", INT32), + ), + isNullable = false, + isNullCall = true, + ) + } + // SUBSTRING (expression, start[, length]?) // SUBSTRINGG(expression from start [FOR length]? ) private fun substring(): List = types.text.map { t -> diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index ec696d417c..44e749bd29 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -99,7 +99,11 @@ internal object RexConverter { * @param ctx * @return */ - private fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { + private fun visitExprCoerce( + node: Expr, + ctx: Env, + coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR, + ): Rex { val rex = super.visitExpr(node, ctx) return when (rex.op is Rex.Op.Select) { true -> rex(StaticType.ANY, rexOpSubquery(rex.op, coercion)) @@ -188,7 +192,10 @@ internal object RexConverter { when (identifierSteps.size) { 0 -> root to node.steps else -> { - val newRoot = rex(StaticType.ANY, rexOpVarUnresolved(mergeIdentifiers(op.identifier, identifierSteps), op.scope)) + val newRoot = rex( + StaticType.ANY, + rexOpVarUnresolved(mergeIdentifiers(op.identifier, identifierSteps), op.scope) + ) val newSteps = node.steps.subList(identifierSteps.size, node.steps.size) newRoot to newSteps } @@ -219,7 +226,10 @@ internal object RexConverter { is Expr.Path.Step.Symbol -> { val identifier = AstToPlan.convert(step.symbol) when (identifier.caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> rexOpPathKey(current, rexString(identifier.symbol)) + Identifier.CaseSensitivity.SENSITIVE -> rexOpPathKey( + current, + rexString(identifier.symbol) + ) Identifier.CaseSensitivity.INSENSITIVE -> rexOpPathSymbol(current, identifier.symbol) } } @@ -516,7 +526,7 @@ internal object RexConverter { TODO("SQL Special Form EXTRACT") } - // TODO: Ignoring type parameter now + // TODO: Ignoring type parameters (EXCEPT DECIMAL) now override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex { val type = node.asType val arg0 = visitExprCoerce(node.value, ctx) @@ -532,7 +542,17 @@ internal object RexConverter { is Type.Real -> TODO("Static Type does not have REAL type") is Type.Float32 -> TODO("Static Type does not have FLOAT32 type") is Type.Float64 -> rex(StaticType.FLOAT, call("cast_float64", arg0)) - is Type.Decimal -> rex(StaticType.DECIMAL, call("cast_decimal", arg0)) + is Type.Decimal -> { + if (type.precision != null) { + // CONSTRAINED — cast_decimal(arg, precision, scale) + val p = rex(StaticType.INT4, rexOpLit(int32Value(type.precision))) + val s = rex(StaticType.INT4, rexOpLit(int32Value(type.scale ?: 0))) + rex(StaticType.DECIMAL, call("cast_decimal", arg0, p, s)) + } else { + // UNCONSTRAINED — cast_decimal(arg) + rex(StaticType.DECIMAL, call("cast_decimal", arg0)) + } + } is Type.Numeric -> rex(StaticType.DECIMAL, call("cast_numeric", arg0)) is Type.Char -> rex(StaticType.CHAR, call("cast_char", arg0)) is Type.Varchar -> rex(StaticType.STRING, call("cast_varchar", arg0)) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/FnResolver.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/FnResolver.kt index d595a051b2..c3deaa10c9 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/FnResolver.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/FnResolver.kt @@ -83,7 +83,7 @@ internal sealed class FnMatch { * @property candidates an ordered list of potentially applicable functions to dispatch dynamically. */ public data class Dynamic( - public val candidates: List> + public val candidates: List>, ) : FnMatch() public data class Error( @@ -354,7 +354,36 @@ internal class FnResolver(private val header: Header) { * * But what about parameterized types? Are the parameters dropped in casts, or do parameters become arguments? */ - private fun castName(type: PartiQLValueType) = "cast_${type.name.lowercase()}" + private fun castName(type: PartiQLValueType): String = when (type) { + ANY -> "cast_any" // TODO remove, only added for backwards compatibility in next release. + BOOL -> "cast_bool" + INT8 -> "cast_int8" + INT16 -> "cast_int16" + INT32 -> "cast_int32" + INT64 -> "cast_int64" + INT -> "cast_int" + DECIMAL -> "cast_decimal" + DECIMAL_ARBITRARY -> "cast_decimal" + FLOAT32 -> "cast_float32" + FLOAT64 -> "cast_float64" + CHAR -> "cast_char" + STRING -> "cast_string" + SYMBOL -> "cast_symbol" + BINARY -> "cast_binary" + BYTE -> "cast_byte" + BLOB -> "cast_blob" + CLOB -> "cast_clob" + DATE -> "cast_date" + TIME -> "cast_time" + TIMESTAMP -> "cast_timestamp" + INTERVAL -> "cast_interval" + BAG -> "cast_bag" + LIST -> "cast_list" + SEXP -> "cast_sexp" + STRUCT -> "cast_struct" + PartiQLValueType.NULL -> "cast_null" // TODO remove, only added for backwards compatibility in next release. + PartiQLValueType.MISSING -> "cast_missing" // TODO remove, only added for backwards compatibility in next release. + } internal fun cast(operand: PartiQLValueType, target: PartiQLValueType) = FunctionSignature.Scalar( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 71a38f6c37..4d45573649 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -78,6 +78,7 @@ import org.partiql.types.AnyType import org.partiql.types.BagType import org.partiql.types.BoolType import org.partiql.types.CollectionType +import org.partiql.types.DecimalType import org.partiql.types.IntType import org.partiql.types.ListType import org.partiql.types.SexpType @@ -91,6 +92,7 @@ import org.partiql.types.StructType import org.partiql.types.TupleConstraint import org.partiql.types.function.FunctionSignature import org.partiql.value.BoolValue +import org.partiql.value.Int32Value import org.partiql.value.MissingValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.TextValue @@ -634,12 +636,37 @@ internal class PlanTyper( } } + // TODO we have to pull out decimal type parameters here because V0 drops the type in CAST. + if (newFn.signature.name == "cast_decimal" && newFn.signature.parameters.size == 3) { + val p = getIntOrErr(newArgs[1].op) + val s = getIntOrErr(newArgs[2].op) + val returns = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(p, s)) + val op = rexOpCallStatic(newFn, newArgs) + return rex(returns, op) + } + // Type return val returns = newFn.signature.returns val op = rexOpCallStatic(newFn, newArgs) return rex(returns.toStaticType().flatten(), op) } + /** + * For `cast_decimal(v, precision, scale)` we make the precision and scale literal 32-bit integers. + */ + private fun getIntOrErr(op: Rex.Op): Int { + if (op !is Rex.Op.Lit) { + error("Unrecoverable, expected Rex.Op.Lit found ${op::class}. This should be unreachable.") + } + if (op.value !is Int32Value) { + error("Unrecoverable, expected Int32Value found ${op.value::class}. This should be unreachable.") + } + if (op.value.value == null) { + error("Int32Value cannot be null. This should be unreachable.") + } + return op.value.value!! + } + override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { // Rewrite CASE-WHEN branches val oldBranches = node.branches.toTypedArray() diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeLattice.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeLattice.kt index a6ade5f7da..9b85bfd8aa 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeLattice.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeLattice.kt @@ -334,6 +334,8 @@ internal class TypeLattice private constructor( INT32 to unsafe(), INT64 to unsafe(), INT to unsafe(), + DECIMAL to unsafe(), + DECIMAL_ARBITRARY to unsafe(), STRING to coercion(), SYMBOL to explicit(), CLOB to coercion(), diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 700404339e..8f3f75c82d 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -32,10 +32,12 @@ import org.partiql.plugins.memory.MemoryConnector import org.partiql.spi.connector.ConnectorMetadata import org.partiql.types.AnyType import org.partiql.types.BagType +import org.partiql.types.DecimalType import org.partiql.types.ListType import org.partiql.types.SexpType import org.partiql.types.StaticType import org.partiql.types.StaticType.Companion.ANY +import org.partiql.types.StaticType.Companion.DECIMAL import org.partiql.types.StaticType.Companion.INT import org.partiql.types.StaticType.Companion.INT4 import org.partiql.types.StaticType.Companion.INT8 @@ -287,6 +289,80 @@ class PlanTyperTestsPorted { @JvmStatic fun structs() = listOf() + @JvmStatic + fun decimalCastCases() = listOf( + SuccessTestCase( + name = "cast decimal", + query = "CAST(1 AS DECIMAL)", + expected = StaticType.DECIMAL, + ), + SuccessTestCase( + name = "cast decimal(1)", + query = "CAST(1 AS DECIMAL(1))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)), + ), + SuccessTestCase( + name = "cast decimal(1,0)", + query = "CAST(1 AS DECIMAL(1,0))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)), + ), + SuccessTestCase( + name = "cast decimal(1,1)", + query = "CAST(1 AS DECIMAL(1,1))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 1)), + ), + SuccessTestCase( + name = "cast decimal(38)", + query = "CAST(1 AS DECIMAL(38))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)), + ), + SuccessTestCase( + name = "cast decimal(38,0)", + query = "CAST(1 AS DECIMAL(38,0))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)), + ), + SuccessTestCase( + name = "cast decimal(38,38)", + query = "CAST(1 AS DECIMAL(38,38))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 38)), + ), + SuccessTestCase( + name = "cast decimal string", + query = "CAST('1' AS DECIMAL)", + expected = StaticType.DECIMAL, + ), + SuccessTestCase( + name = "cast decimal(1) string", + query = "CAST('1' AS DECIMAL(1))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)), + ), + SuccessTestCase( + name = "cast decimal(1,0) string", + query = "CAST('1' AS DECIMAL(1,0))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)), + ), + SuccessTestCase( + name = "cast decimal(1,1) string", + query = "CAST('1' AS DECIMAL(1,1))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 1)), + ), + SuccessTestCase( + name = "cast decimal(38) string", + query = "CAST('1' AS DECIMAL(38))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)), + ), + SuccessTestCase( + name = "cast decimal(38,0) string", + query = "CAST('1' AS DECIMAL(38,0))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)), + ), + SuccessTestCase( + name = "cast decimal(38,38) string", + query = "CAST('1' AS DECIMAL(38,38))", + expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 38)), + ), + ) + @JvmStatic fun selectStar() = listOf( SuccessTestCase( @@ -3452,6 +3528,11 @@ class PlanTyperTestsPorted { @Execution(ExecutionMode.CONCURRENT) fun testCollections(tc: TestCase) = runTest(tc) + @ParameterizedTest + @MethodSource("decimalCastCases") + @Execution(ExecutionMode.CONCURRENT) + fun testDecimalCast(tc: TestCase) = runTest(tc) + @ParameterizedTest @MethodSource("selectStar") @Execution(ExecutionMode.CONCURRENT)