Skip to content

Commit

Permalink
Fixes small evaluation bugs related to casts, types, etc (#1577)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn authored Oct 2, 2024
1 parent bb2947f commit 433c609
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 125 deletions.
2 changes: 1 addition & 1 deletion partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ internal class MainCommand : Runnable {
}
is PartiQLResult.Value -> {
val writer = PartiQLValueTextWriter(System.out)
writer.append(result.value)
writer.append(result.value.toPartiQLValue()) // TODO: Create a Datum writer
println()
}
}
Expand Down
2 changes: 1 addition & 1 deletion partiql-cli/src/main/kotlin/org/partiql/cli/shell/Shell.kt
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ internal class Shell(
is PartiQLResult.Error -> throw result.cause
is PartiQLResult.Value -> {
val writer = PartiQLValueTextWriter(out)
writer.append(result.value)
writer.append(result.value.toPartiQLValue()) // TODO: Create a Datum writer
out.appendLine()
out.appendLine()
out.info("OK!")
Expand Down
10 changes: 5 additions & 5 deletions partiql-eval/api/partiql-eval.api
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ public final class org/partiql/eval/PartiQLResult$Error : org/partiql/eval/Parti
}

public final class org/partiql/eval/PartiQLResult$Value : org/partiql/eval/PartiQLResult {
public fun <init> (Lorg/partiql/value/PartiQLValue;)V
public final fun component1 ()Lorg/partiql/value/PartiQLValue;
public final fun copy (Lorg/partiql/value/PartiQLValue;)Lorg/partiql/eval/PartiQLResult$Value;
public static synthetic fun copy$default (Lorg/partiql/eval/PartiQLResult$Value;Lorg/partiql/value/PartiQLValue;ILjava/lang/Object;)Lorg/partiql/eval/PartiQLResult$Value;
public fun <init> (Lorg/partiql/spi/value/Datum;)V
public final fun component1 ()Lorg/partiql/spi/value/Datum;
public final fun copy (Lorg/partiql/spi/value/Datum;)Lorg/partiql/eval/PartiQLResult$Value;
public static synthetic fun copy$default (Lorg/partiql/eval/PartiQLResult$Value;Lorg/partiql/spi/value/Datum;ILjava/lang/Object;)Lorg/partiql/eval/PartiQLResult$Value;
public fun equals (Ljava/lang/Object;)Z
public final fun getValue ()Lorg/partiql/value/PartiQLValue;
public final fun getValue ()Lorg/partiql/spi/value/Datum;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package org.partiql.eval

import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.spi.value.Datum

public sealed interface PartiQLResult {

@OptIn(PartiQLValueExperimental::class)
public data class Value(public val value: PartiQLValue) : PartiQLResult
public data class Value(public val value: Datum) : PartiQLResult

public data class Error(public val cause: Throwable) : PartiQLResult
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@ import org.partiql.eval.PartiQLStatement
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.operator.Operator
import org.partiql.spi.catalog.Session
import org.partiql.value.PartiQLValueExperimental

internal class QueryStatement(root: Operator.Expr) : PartiQLStatement {

// DO NOT USE FINAL
private var _root = root

@OptIn(PartiQLValueExperimental::class)
override fun execute(session: Session): PartiQLResult {
val datum = _root.eval(Environment.empty)
val value = PartiQLResult.Value(datum.toPartiQLValue())
val value = PartiQLResult.Value(datum)
return value
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.partiql.eval.internal

import com.amazon.ionelement.api.createIonElementLoader
import com.amazon.ionelement.api.loadSingleElement
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
Expand All @@ -12,16 +11,15 @@ import org.partiql.eval.PartiQLEngine
import org.partiql.eval.PartiQLResult
import org.partiql.parser.PartiQLParser
import org.partiql.plan.v1.PartiQLPlan
import org.partiql.planner.builder.PartiQLPlannerBuilder
import org.partiql.planner.internal.SqlPlannerV1
import org.partiql.plugins.memory.MemoryCatalog
import org.partiql.plugins.memory.MemoryTable
import org.partiql.spi.catalog.Name
import org.partiql.spi.catalog.Session
import org.partiql.spi.value.Datum
import org.partiql.spi.value.ion.IonDatum
import org.partiql.types.PType
import org.partiql.types.StaticType
import org.partiql.value.CollectionValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.bagValue
Expand All @@ -36,6 +34,7 @@ import org.partiql.value.missingValue
import org.partiql.value.nullValue
import org.partiql.value.stringValue
import org.partiql.value.structValue
import org.partiql.value.symbolValue
import java.io.ByteArrayOutputStream
import java.math.BigDecimal
import java.math.BigInteger
Expand Down Expand Up @@ -77,8 +76,70 @@ class PartiQLEngineDefaultTest {
@Execution(ExecutionMode.CONCURRENT)
fun globalsTests(tc: SuccessTestCase) = tc.assert()

@ParameterizedTest
@MethodSource("castTestCases")
@Execution(ExecutionMode.CONCURRENT)
fun castTests(tc: SuccessTestCase) = tc.assert()

companion object {

@JvmStatic
fun castTestCases() = listOf(
SuccessTestCase(
input = """
CAST(20 AS DECIMAL(10, 5));
""".trimIndent(),
expected = decimalValue(BigDecimal.valueOf(2000000, 5))
),
SuccessTestCase(
input = """
CAST(20 AS DECIMAL(10, 3));
""".trimIndent(),
expected = decimalValue(BigDecimal.valueOf(20000, 3))
),
SuccessTestCase(
input = """
CAST(20 AS DECIMAL(2, 0));
""".trimIndent(),
expected = decimalValue(BigDecimal.valueOf(20, 0))
),
SuccessTestCase(
input = """
CAST(20 AS DECIMAL(1, 0));
""".trimIndent(),
expected = missingValue(),
mode = PartiQLEngine.Mode.PERMISSIVE
),
SuccessTestCase(
input = """
1 + 2.0
""".trimIndent(),
expected = decimalValue(BigDecimal.valueOf(30, 1))
),
SuccessTestCase(
input = "SELECT DISTINCT VALUE t * 100 FROM <<0, 1, 2.0, 3.0>> AS t;",
expected = bagValue(
int32Value(0),
int32Value(100),
decimalValue(BigDecimal.valueOf(2000, 1)),
decimalValue(BigDecimal.valueOf(3000, 1)),
)
),
SuccessTestCase(
input = """
CAST(20 AS SYMBOL);
""".trimIndent(),
expected = symbolValue("20"),
),
// TODO: Use Datum for assertions. Currently, PartiQLValue doesn't support parameterized CHAR/VARCHAR
// SuccessTestCase(
// input = """
// CAST(20 AS CHAR(2));
// """.trimIndent(),
// expected = charValue("20"),
// ),
)

@JvmStatic
fun globalsTestCases() = listOf(
SuccessTestCase(
Expand Down Expand Up @@ -1246,9 +1307,7 @@ class PartiQLEngineDefaultTest {
) {

private val engine = PartiQLEngine.builder().build()
private val planner = PartiQLPlannerBuilder().build()
private val parser = PartiQLParser.default()
private val loader = createIonElementLoader()

/**
* @property value is a serialized Ion value.
Expand Down Expand Up @@ -1288,7 +1347,7 @@ class PartiQLEngineDefaultTest {
throw returned.cause
}
}
val output = result.value
val output = result.value.toPartiQLValue() // TODO: Assert directly on Datum
assert(expected == output) {
comparisonString(expected, output, plan)
}
Expand Down Expand Up @@ -1321,37 +1380,38 @@ class PartiQLEngineDefaultTest {
) {

private val engine = PartiQLEngine.builder().build()
private val planner = PartiQLPlannerBuilder().build()
private val parser = PartiQLParser.default()

internal fun assert() {
val permissiveResult = run(mode = PartiQLEngine.Mode.PERMISSIVE)
val (permissiveResult, plan) = run(mode = PartiQLEngine.Mode.PERMISSIVE)
val permissiveResultPValue = permissiveResult.toPartiQLValue()
val assertionCondition = try {
expectedPermissive == permissiveResult.first
expectedPermissive == permissiveResultPValue // TODO: Assert using Datum
} catch (t: Throwable) {
val str = buildString {
appendLine("Test Name: $name")
// TODO pretty-print V1 plans!
appendLine(permissiveResult.second)
appendLine(plan)
}
throw RuntimeException(str, t)
}
assert(assertionCondition) {
comparisonString(expectedPermissive, permissiveResult.first, permissiveResult.second)
comparisonString(expectedPermissive, permissiveResultPValue, plan)
}
var error: Throwable? = null
try {
when (val result = run(mode = PartiQLEngine.Mode.STRICT).first) {
is CollectionValue<*> -> result.toList()
else -> result
val (strictResult, _) = run(mode = PartiQLEngine.Mode.STRICT)
when (strictResult.type.kind) {
PType.Kind.BAG, PType.Kind.ARRAY, PType.Kind.SEXP -> strictResult.toList()
else -> strictResult
}
} catch (e: Throwable) {
error = e
}
assertNotNull(error)
}

private fun run(mode: PartiQLEngine.Mode): Pair<PartiQLValue, PartiQLPlan> {
private fun run(mode: PartiQLEngine.Mode): Pair<Datum, PartiQLPlan> {
val statement = parser.parse(input).root
val catalog = MemoryCatalog.builder().name("memory").build()
val session = Session.builder()
Expand Down Expand Up @@ -1419,20 +1479,6 @@ class PartiQLEngineDefaultTest {
tc.assert()
}

@Test
@Disabled("CASTS have not yet been implemented.")
fun testCast1() = SuccessTestCase(
input = "1 + 2.0",
expected = int32Value(3),
).assert()

@Test
@Disabled("CASTS have not yet been implemented.")
fun testCasts() = SuccessTestCase(
input = "SELECT DISTINCT VALUE t * 100 FROM <<0, 1, 2.0, 3.0>> AS t;",
expected = bagValue(int32Value(0), int32Value(100), int32Value(200), int32Value(300))
).assert()

@Test
@Disabled("We need to support section 5.1")
fun testTypingOfPositionVariable() = TypingTestCase(
Expand Down
Loading

0 comments on commit 433c609

Please sign in to comment.