From 83e519c570a2b6c8612fa455cfa85200a969a995 Mon Sep 17 00:00:00 2001 From: Tamas Vajk Date: Thu, 21 Nov 2024 15:46:57 +0100 Subject: [PATCH] KE2: Extract `when` expressions --- .../src/main/kotlin/entities/Expression.kt | 123 ++++++++++++++---- java/ql/lib/config/semmlecode.dbscheme | 20 +++ java/ql/lib/semmle/code/java/Expr.qll | 61 ++++++++- 3 files changed, 175 insertions(+), 29 deletions(-) diff --git a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt index 72723446fdce..ab93906be720 100644 --- a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt +++ b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt @@ -808,6 +808,10 @@ private fun KotlinFileExtractor.extractExpression( return extractIf(e, parent, callable) } + is KtWhenExpression -> { + return extractWhen(e, parent, callable) + } + is KtWhileExpression -> { extractLoopWithCondition(e, parent, callable) } @@ -1160,32 +1164,6 @@ private fun KotlinFileExtractor.extractExpression( return } - val exprParent = parent.expr(e, callable) - val id = tw.getFreshIdLabel() - val type = useType(e.type) - val locId = tw.getLocation(e) - tw.writeExprs_whenexpr( - id, - type.javaResult.id, - exprParent.parent, - exprParent.idx - ) - tw.writeExprsKotlinType(id, type.kotlinResult.id) - extractExprContext(id, locId, callable, exprParent.enclosingStmt) - if (e.origin == IrStatementOrigin.IF) { - tw.writeWhen_if(id) - } - e.branches.forEachIndexed { i, b -> - val bId = tw.getFreshIdLabel() - val bLocId = tw.getLocation(b) - tw.writeStmts_whenbranch(bId, id, i, callable) - tw.writeHasLocation(bId, bLocId) - extractExpressionExpr(b.condition, callable, bId, 0, bId) - extractExpressionStmt(b.result, callable, bId, 1) - if (b is IrElseBranch) { - tw.writeWhen_branch_else(bId) - } - } } is IrGetClass -> { val exprParent = parent.expr(e, callable) @@ -1737,6 +1715,99 @@ private fun KotlinFileExtractor.extractLoop( return id } +context(KaSession) +private fun KotlinFileExtractor.extractWhen( + e: KtWhenExpression, + parent: StmtExprParent, + callable: Label +): Label? { + val exprParent = parent.expr(e, callable) + val id = tw.getFreshIdLabel() + val type = useType(e.expressionType) + val locId = tw.getLocation(e) + tw.writeExprs_whenexpr( + id, + type.javaResult.id, + exprParent.parent, + exprParent.idx + ) + tw.writeExprsKotlinType(id, type.kotlinResult.id) + extractExprContext(id, locId, callable, exprParent.enclosingStmt) + + /* OLD: KE1, Should we remove this `when_if` DB construct? + if (e.origin == IrStatementOrigin.IF) { + tw.writeWhen_if(id) + } + */ + + if (e.subjectVariable != null) { + extractVariableExpr(e.subjectVariable!!, callable, id, -1, exprParent.enclosingStmt) + } else if (e.subjectExpression != null) { + extractExpressionExpr(e.subjectExpression!!, callable, id, -1, exprParent.enclosingStmt) + } + + e.entries.forEachIndexed { i, b -> + val bId = tw.getFreshIdLabel() + val bLocId = tw.getLocation(b) + tw.writeStmts_whenbranch(bId, id, i, callable) + tw.writeHasLocation(bId, bLocId) + for ((idx, cond) in b.conditions.withIndex()) { + val condId = tw.getFreshIdLabel() + val locId = tw.getLocation(cond) + tw.writeStmts_whenbranchcondition(condId, bId, -1 * idx, callable) + tw.writeHasLocation(id, locId) + + when (cond) { + is KtWhenConditionWithExpression -> { + tw.writeWhen_branch_condition_with_expr(condId) + extractExpressionExpr( + cond.expression!!, + callable, + condId, + 0, + condId + ) + } + + is KtWhenConditionInRange -> { + // [!]in 1..10 + tw.writeWhen_branch_condition_with_range(condId, cond.isNegated) + extractExpressionExpr( + cond.rangeExpression!!, + callable, + condId, + 0, + condId + ) + } + + is KtWhenConditionIsPattern -> { + // [!]is Type + val type = useType(cond.typeReference?.type) + tw.writeWhen_branch_condition_with_pattern( + condId, + cond.isNegated, + type.javaResult.id, + type.kotlinResult.id + ) + } + } + } + + extractExpressionStmt(b.expression!!, callable, bId, 1) + val guardExpr = b.guard?.getExpression() + if (guardExpr != null) { + extractExpressionStmt(guardExpr, callable, bId, 2) + } + + if (b.isElse) { + tw.writeWhen_branch_else(bId) + } + } + + return id +} + context(KaSession) private fun KotlinFileExtractor.extractIf( ifStmt: KtIfExpression, diff --git a/java/ql/lib/config/semmlecode.dbscheme b/java/ql/lib/config/semmlecode.dbscheme index e0fd7f76ab86..d8680c03d33c 100644 --- a/java/ql/lib/config/semmlecode.dbscheme +++ b/java/ql/lib/config/semmlecode.dbscheme @@ -918,6 +918,7 @@ case @stmt.kind of | 23 = @yieldstmt | 24 = @errorstmt | 25 = @whenbranch +| 26 = @whenbranchcondition ; #keyset[parent,idx] @@ -1047,6 +1048,25 @@ when_if(unique int id: @whenexpr ref); /** Holds if this `when` branch was written as an `else` branch. */ when_branch_else(unique int id: @whenbranch ref); +/** Holds if this `when` branch condition has an expression. */ +when_branch_condition_with_expr( + unique int id: @whenbranchcondition ref + ); + +/** Holds if this `when` branch condition has a range. */ +when_branch_condition_with_range( + unique int id: @whenbranchcondition ref, + boolean isNegated: boolean ref + ); + +/** Holds if this `when` branch condition has a type pattern. */ +when_branch_condition_with_pattern( + unique int id: @whenbranchcondition ref, + boolean isNegated: boolean ref, + int typeid: @type ref, + int kttypeid: @kt_type ref + ); + @classinstancexpr = @newexpr | @lambdaexpr | @memberref | @propertyref @annotation = @declannotation | @typeannotation diff --git a/java/ql/lib/semmle/code/java/Expr.qll b/java/ql/lib/semmle/code/java/Expr.qll index b4ab37195e09..d813847868e8 100644 --- a/java/ql/lib/semmle/code/java/Expr.qll +++ b/java/ql/lib/semmle/code/java/Expr.qll @@ -2566,16 +2566,37 @@ class WhenExpr extends Expr, StmtParent, @whenexpr { override string getAPrimaryQlClass() { result = "WhenExpr" } /** Gets the `i`th branch. */ - WhenBranch getBranch(int i) { result.isNthChildOf(this, i) } + WhenBranch getBranch(int i) { result.isNthChildOf(this, i) and i >= 0 } /** Holds if this was written as an `if` expression. */ predicate isIf() { when_if(this) } + + /** Gets the expression of this `when` expression, if any. */ + Expr getExpr() { result.isNthChildOf(this, -1) } + + /** Gets the local variable declaration of this `when` expression, if any. */ + LocalVariableDeclExpr getAVariableDeclExpr() { result.isNthChildOf(this, -1) } } /** A Kotlin `when` branch. */ class WhenBranch extends Stmt, @whenbranch { - /** Gets the condition of this branch. */ - Expr getCondition() { result.isNthChildOf(this, 0) } + /** + * DEPRECATED: Use `getACondition` or `getCondition/1` instead. + * + * Gets the condition of this branch. + */ + deprecated Expr getCondition() { + result = this.getCondition(0).(WhenBranchConditionWithExpression).getExpression() + } + + /** Gets the `i`th condition of this branch. */ + WhenBranchCondition getCondition(int i) { i <= 0 and result.isNthChildOf(this, i) } + + /** Gets a condition of this branch. */ + WhenBranchCondition getACondition() { result = this.getCondition(_) } + + /** Gets the guard applicable to this branch, if any. */ + Expr getGuard() { result.isNthChildOf(this, 2) } /** Gets the result of this branch. */ Stmt getRhs() { result.isNthChildOf(this, 1) } @@ -2594,6 +2615,40 @@ class WhenBranch extends Stmt, @whenbranch { override string getAPrimaryQlClass() { result = "WhenBranch" } } +abstract class WhenBranchCondition extends Stmt, @whenbranchcondition { } + +/** A Kotlin `when` branch condition with an expression. */ +class WhenBranchConditionWithExpression extends WhenBranchCondition { + WhenBranchConditionWithExpression() { when_branch_condition_with_expr(this) } + + Expr getExpression() { result.isNthChildOf(this, 0) } +} + +/** A Kotlin `when` branch condition with a range. */ +class WhenBranchConditionWithRange extends WhenBranchCondition { + WhenBranchConditionWithRange() { when_branch_condition_with_range(this, _) } + + /** Holds if this is a negated range condition. */ + predicate isNegated() { when_branch_condition_with_range(this, true) } + + /** + * Gets the range of this branch condition. + * Ranges are represented by calls to `operator fun > T.rangeTo(that: T): ClosedRange`. + */ + MethodCall getRange() { result.isNthChildOf(this, 0) } +} + +/** A Kotlin `when` branch condition with a pattern. */ +class WhenBranchConditionWithPattern extends WhenBranchCondition { + WhenBranchConditionWithPattern() { when_branch_condition_with_pattern(this, _, _, _) } + + /** Holds if this is a negated pattern condition. */ + predicate isNegated() { when_branch_condition_with_pattern(this, true, _, _) } + + /** Gets the type pattern of this branch condition. */ + Type getType() { when_branch_condition_with_pattern(this, _, result, _) } +} + // TODO: This might need more cases. It might be better as a predicate // on Stmt, overridden in each subclass. private Expr getAResult(Stmt s) {