diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index 0dd1a879ff..91506e2fc7 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -3,14 +3,9 @@ package org.partiql.ast.helpers import com.amazon.ion.Decimal -import com.amazon.ionelement.api.DecimalElement -import com.amazon.ionelement.api.FloatElement -import com.amazon.ionelement.api.IntElement -import com.amazon.ionelement.api.IntElementSize import com.amazon.ionelement.api.MetaContainer import com.amazon.ionelement.api.emptyMetaContainer import com.amazon.ionelement.api.ionDecimal -import com.amazon.ionelement.api.ionFloat import com.amazon.ionelement.api.ionInt import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionSymbol @@ -48,7 +43,6 @@ import org.partiql.value.TimestampValue import org.partiql.value.datetime.TimeZone import org.partiql.value.toIon import java.math.BigDecimal -import java.math.BigInteger /** * Translates an [AstNode] tree to the legacy PIG AST. @@ -322,42 +316,8 @@ private class AstTranslator(val metas: Map) : AstBaseVisi val arg = visitExpr(node.expr, ctx) when (node.op) { Expr.Unary.Op.NOT -> not(arg, metas) - Expr.Unary.Op.POS -> { - when { - arg !is PartiqlAst.Expr.Lit -> pos(arg) - arg.value is IntElement -> arg - arg.value is FloatElement -> arg - arg.value is DecimalElement -> arg - else -> pos(arg) - } - } - Expr.Unary.Op.NEG -> { - when { - arg !is PartiqlAst.Expr.Lit -> neg(arg, metas) - arg.value is IntElement -> { - val intValue = when (arg.value.integerSize) { - IntElementSize.LONG -> ionInt(-arg.value.longValue) - IntElementSize.BIG_INTEGER -> when (arg.value.bigIntegerValue) { - Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE) - else -> ionInt(arg.value.bigIntegerValue * BigInteger.valueOf(-1L)) - } - } - arg.copy( - value = intValue.asAnyElement(), - metas = metas, - ) - } - arg.value is FloatElement -> arg.copy( - value = ionFloat(-(arg.value.doubleValue)).asAnyElement(), - metas = metas, - ) - arg.value is DecimalElement -> arg.copy( - value = ionDecimal(Decimal.valueOf(-(arg.value.decimalValue))).asAnyElement(), - metas = metas, - ) - else -> neg(arg, metas) - } - } + Expr.Unary.Op.POS -> pos(arg, metas) + Expr.Unary.Op.NEG -> neg(arg, metas) } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt new file mode 100644 index 0000000000..22ca8668b6 --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt @@ -0,0 +1,74 @@ +package org.partiql.lang.syntax + +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +class PartiQLParserLiteralTests : PartiQLParserTestBase() { + + override val targets: Array = arrayOf(ParserTarget.EXPERIMENTAL) + + @ParameterizedTest + @MethodSource("cases") + @Execution(ExecutionMode.CONCURRENT) + fun testAll(case: Case) { + assertExpression( + source = case.input, + expectedPigAst = case.expect, + ) + Long.MAX_VALUE + } + + companion object { + + @JvmStatic + fun cases() = listOf( + Case( + input = "1", + expect = "(lit 1)" + ), + Case( + input = "+-1", + expect = "(lit -1)" + ), + Case( + input = "-+1", + expect = "(lit -1)" + ), + Case( + input = "-+-1", + expect = "(lit 1)" + ), + Case( + input = "+++1", + expect = "(lit 1)" + ), + Case( + input = "-1", + expect = "(lit -1)" + ), + Case( + input = "+1", + expect = "(lit 1)" + ), + Case( + input = "9223372036854775808", // Long.MAX_VALUE + 1 + expect = "(lit 9223372036854775808)" + ), + Case( + input = "-9223372036854775809", // Long.MIN_VALUE - 1 + expect = "(lit -9223372036854775809)" + ), + Case( + input = "+9223372036854775808", + expect = "(lit 9223372036854775808)" + ), + ) + } + + class Case( + @JvmField val input: String, + @JvmField val expect: String, + ) +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt index a72fefa81f..c781ba9be7 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt @@ -295,6 +295,7 @@ class PartiQLParserTest : PartiQLParserTestBase() { } @Test + @Ignore("Disabled because it's not clear that the parser should be pushing down negations on boxed Ion values") fun unaryIonFloatLiteral() { assertExpression( "+-+-+-`-5e0`", diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index ac30689f17..4ba104014c 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -209,6 +209,7 @@ import org.partiql.parser.SourceLocation import org.partiql.parser.SourceLocations import org.partiql.parser.antlr.PartiQLParserBaseVisitor import org.partiql.parser.internal.util.DateTimeUtils +import org.partiql.parser.internal.util.NumberUtils.negate import org.partiql.value.NumericValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.StringValue @@ -569,17 +570,18 @@ internal class PartiQLParserDefault : PartiQLParser { } } - override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) = translate(ctx) { - val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } - val name = visitSymbolPrimitive(ctx.name) - if (qualifier.isEmpty()) { - name - } else { - val root = qualifier.first() - val steps = qualifier.drop(1) + listOf(name) - identifierQualified(root, steps) + override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) = + translate(ctx) { + val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } + val name = visitSymbolPrimitive(ctx.name) + if (qualifier.isEmpty()) { + name + } else { + val root = qualifier.first() + val steps = qualifier.drop(1) + listOf(name) + identifierQualified(root, steps) + } } - } /** * @@ -1488,9 +1490,34 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitValueExpr(ctx: GeneratedParser.ValueExprContext) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - val expr = visit(ctx.rhs) as Expr - exprUnary(convertUnaryOp(ctx.sign), expr) + // expression + if (ctx.parent != null) { + return@translate visit(ctx.parent) + } + // unary expression + val op = when (ctx.sign.type) { + GeneratedParser.NOT -> Expr.Unary.Op.NOT + GeneratedParser.PLUS -> Expr.Unary.Op.POS + GeneratedParser.MINUS -> Expr.Unary.Op.NEG + else -> throw error(ctx.sign, "Invalid unary operator") + } + // If argument is not a literal, then return the op. + val arg = visit(ctx.rhs) as Expr + return when (arg) { + is Expr.Lit -> arg.negate(op) + // TODO should we unwrap and negate Ion values for -`-1`? I don't think so.. + is Expr.Ion -> exprUnary(op, arg) + else -> exprUnary(op, arg) + } + } + + private fun Expr.Lit.negate(op: Expr.Unary.Op): Expr { + val v = this.value + return when { + op == Expr.Unary.Op.POS && v is NumericValue<*> -> exprLit(v) + op == Expr.Unary.Op.NEG && v is NumericValue<*> -> exprLit(v.negate()) + else -> exprUnary(op, exprLit(v)) + } } private fun convertBinaryExpr(lhs: ParserRuleContext, rhs: ParserRuleContext, op: Expr.Binary.Op): Expr { diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/util/NumberUtils.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/util/NumberUtils.kt new file mode 100644 index 0000000000..b84c8edff9 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/util/NumberUtils.kt @@ -0,0 +1,55 @@ +package org.partiql.parser.internal.util + +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue +import org.partiql.value.NumericValue +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import java.math.BigInteger + +internal object NumberUtils { + + /** + * We might consider a `negate` method on the NumericValue but this is fine for now and is internal. + */ + @OptIn(PartiQLValueExperimental::class) + internal fun NumericValue<*>.negate(): NumericValue<*> = when (this) { + is DecimalValue -> decimalValue(value?.negate()) + is Float32Value -> float32Value(value?.let { it * -1 }) + is Float64Value -> float64Value(value?.let { it * -1 }) + is Int8Value -> when (value) { + null -> this + Byte.MIN_VALUE -> int16Value(value?.let { (it.toInt() * -1).toShort() }) + else -> int8Value(value?.let { (it.toInt() * -1).toByte() }) + } + is Int16Value -> when (value) { + null -> this + Short.MIN_VALUE -> int32Value(value?.let { it.toInt() * -1 }) + else -> int16Value(value?.let { (it.toInt() * -1).toShort() }) + } + is Int32Value -> when (value) { + null -> this + Int.MIN_VALUE -> int64Value(value?.let { it.toLong() * -1 }) + else -> int32Value(value?.let { it * -1 }) + } + is Int64Value -> when (value) { + null -> this + Long.MIN_VALUE -> intValue(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE)) + else -> int64Value(value?.let { it * -1 }) + } + is IntValue -> intValue(value?.negate()) + } +} diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/util/NumberUtilsTest.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/util/NumberUtilsTest.kt new file mode 100644 index 0000000000..22852e9a77 --- /dev/null +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/util/NumberUtilsTest.kt @@ -0,0 +1,36 @@ +package org.partiql.parser.internal.util + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.partiql.parser.internal.util.NumberUtils.negate +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.decimalValue +import org.partiql.value.int16Value +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.int8Value +import org.partiql.value.intValue +import java.math.BigDecimal +import java.math.BigInteger + +@OptIn(PartiQLValueExperimental::class) +class NumberUtilsTest { + + @Test + fun negate_normal() { + assertEquals(int8Value(-1), int8Value(1).negate()) + assertEquals(int16Value(-1), int16Value(1).negate()) + assertEquals(int32Value(-1), int32Value(1).negate()) + assertEquals(int64Value(-1), int64Value(1).negate()) + assertEquals(intValue(BigInteger.valueOf(-1L)), intValue(BigInteger.valueOf(1L)).negate()) + assertEquals(decimalValue(BigDecimal.valueOf(-1L)), decimalValue(BigDecimal.valueOf(1L)).negate()) + } + + @Test + fun negate_overflow() { + assertEquals(int16Value((Byte.MAX_VALUE.toShort() + 1).toShort()), int8Value(Byte.MIN_VALUE).negate()) + assertEquals(int32Value((Short.MAX_VALUE.toInt() + 1)), int16Value(Short.MIN_VALUE).negate()) + assertEquals(int64Value((Int.MAX_VALUE.toLong() + 1)), int32Value(Int.MIN_VALUE).negate()) + assertEquals(intValue(BigInteger.valueOf(Long.MAX_VALUE) + BigInteger.ONE), int64Value(Long.MIN_VALUE).negate()) + } +}