Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime evaluation improvements #741

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ object RuntimeEvaluationTree {
def args: Seq[RuntimeEvaluationTree]
}

sealed trait Assignable extends RuntimeEvaluationTree
sealed trait Assignable extends RuntimeEvaluationTree {
def isMutable: Boolean
}

sealed trait Field extends Assignable {
def field: jdi.Field
def isMutable: Boolean = !field.isFinal
override def isMutable: Boolean = !field.isFinal
}

case class LocalVar(name: String, `type`: jdi.Type) extends Assignable {
override val isMutable: Boolean = true
override def prettyPrint(depth: Int): String = {
val indent = " " * (depth + 1)
s"""|LocalVar(
Expand Down Expand Up @@ -125,7 +128,7 @@ object RuntimeEvaluationTree {
}

case class NewInstance(init: CallStaticMethod) extends RuntimeEvaluationTree {
override lazy val `type`: jdi.ReferenceType = init.method.declaringType() // .asInstanceOf[jdi.ClassType]
override lazy val `type`: jdi.ReferenceType = init.method.declaringType()
override def prettyPrint(depth: Int): String = {
val indent = " " * (depth + 1)
s"""|NewInstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
.orElse(findTopLevelModule(name))
case _: Term.This => thisTree
case sup: Term.Super => Recoverable("Super not (yet) supported at runtime")
case _: Term.Apply | _: Term.ApplyInfix | _: Term.ApplyUnary => validateMethod(standardize(expression))
case expression @ Term.Apply.After_4_6_0(Term.Name(fun), _) =>
for {
call <- standardize(expression)
method <-
validateMethod(call)
.orElse(findClass("Predef").flatMap(x => findStaticMethodBySignedName(x.`type`, fun, call.args)))
} yield method
case _: Term.Apply | _: Term.ApplyInfix | _: Term.ApplyUnary => standardize(expression).flatMap(validateMethod)
case select: Term.Select =>
validateAsValueOrClass(select.qual)
.flatMap {
Expand Down Expand Up @@ -125,7 +132,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
case (err: Invalid, _) => err
}

private def unitTree: Validation[RuntimeEvaluationTree] = validateLiteral(Lit.Unit())
private val unitTree: Validation[RuntimeEvaluationTree] = validateLiteral(Lit.Unit())

private def validateLiteral(lit: Lit): Validation[RuntimeEvaluationTree] =
classLoader.map { loader =>
Expand All @@ -139,41 +146,36 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
qualifier: RuntimeEvaluationTree,
name: String,
preevaluate: Boolean = preEvaluation
): Validation[RuntimeEvaluationTree] = {
): Validation[RuntimeEvaluationTree] =
for {
qualifierTpe <- asReference(qualifier.`type`)
field <- findField(name, qualifierTpe)
_ = loadClassOnNeed(field)
fieldTree <- asInstanceField(field, qualifier, preevaluate)
} yield fieldTree
}

private def findStaticField(
qualifier: jdi.ReferenceType,
name: String,
preevaluate: Boolean = preEvaluation
): Validation[RuntimeEvaluationTree] = {
for {
field <- findField(name, qualifier)
_ = loadClassOnNeed(field)
fieldTree <- asStaticField(field, preevaluate = preevaluate)
} yield fieldTree
}
): Validation[RuntimeEvaluationTree] =
findField(name, qualifier).flatMap(asStaticField(_, preevaluate))

private def findVariable(name: String, preevaluate: Boolean = preEvaluation): Validation[RuntimeEvaluationTree] = {
val encodedName = NameTransformer.encode(name)
Validation
.fromOption(frame.variableByName(encodedName), s"$name is not a local variable")
.filter(_.`type`.name != "scala.Function0", v => s"${v.name} could be a by-name argument")
.map(v => LocalVar(encodedName, v.`type`))
.map(v => if (preevaluate) preEvaluate(v) else v)
.filter("value could be a by-name argument")(_.`type`.name != "scala.Function0")
.map { localVar =>
val v = LocalVar(encodedName, localVar.`type`)
if (preevaluate) preEvaluate(v) else v
}
}

private def findField(name: String, ref: jdi.ReferenceType): Validation[jdi.Field] = {
val encodedName = NameTransformer.encode(name)
def fieldOpt = Option(ref.fieldByName(encodedName))
.orElse(ref.visibleFields.asScala.find(_.name.endsWith("$" + encodedName)))
Validation.fromOption(fieldOpt, s"$name is not a field in ${ref.name}")
Validation.fromOption(fieldOpt, s"$name is not a field in ${ref.name}").tap(loadClassOnNeed)
}

private sealed trait RuntimeClass {
Expand All @@ -190,7 +192,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source

private def findStaticClass(name: String, qualifier: jdi.ReferenceType): Validation[RuntimeClass] =
findInnerClass(name, qualifier)
.filter(_.isStatic, cls => s"${cls.name} is not a static class")
.filter("class is not a static class")(_.isStatic)
.map(StaticOrTopLevelClass.apply)

private def findMemberClass(name: String, qualifier: RuntimeEvaluationTree): Validation[RuntimeClass] =
Expand Down Expand Up @@ -258,13 +260,11 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
.orElse(findOuter(qualifier).flatMap(findMemberOrModule(name, _)))

private def findModule(name: String, qualifier: RuntimeEvaluationTree): Validation[RuntimeEvaluationTree] = {
val moduleClassName = name.stripSuffix("$") + "$"
val moduleClassName = toModuleName(name)
val qualifierTypeName = qualifier.`type`.name
findMemberClass(moduleClassName, qualifier).flatMap { cls =>
if (inCompanion(qualifierTypeName, moduleClassName))
Recoverable(s"Cannot access module $name from $qualifierTypeName")
else asModule(cls.`type`, qualifier)
}
if (inCompanion(qualifierTypeName, moduleClassName))
Recoverable(s"Cannot access module $name from $qualifierTypeName")
else findMemberClass(moduleClassName, qualifier).flatMap(cls => asModule(cls.`type`, qualifier))
}

private def inCompanion(qualifierTypeName: String, moduleClassName: String) =
Expand All @@ -273,48 +273,50 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
.withFilter(_.methodsByName(moduleClassName.stripSuffix("$")).asScala.nonEmpty)
.isValid

private def toModuleName(name: String) = if (name.endsWith("$")) name else name + "$"
private def findTopLevelModule(name: String): Validation[RuntimeEvaluationTree] =
findTopLevelClass(name.stripSuffix("$") + "$").flatMap(cls => asStaticModule(cls.`type`))
findTopLevelClass(toModuleName(name)).flatMap(cls => asStaticModule(cls.`type`))

private def findTopLevelModule(pkg: String, name: String): Validation[RuntimeEvaluationTree] =
findQualifiedClass(name.stripSuffix("$") + "$", pkg).flatMap(asStaticModule)
findQualifiedClass(toModuleName(name), pkg).flatMap(asStaticModule)

private def findStaticMember(name: String, qualifier: jdi.ReferenceType): Validation[RuntimeEvaluationTree] =
findStaticField(qualifier, name).orElse(findZeroArgStaticMethod(qualifier, name))

/* Standardized method call */
private case class Call(fun: Term, argClause: Term.ArgClause)
private case class Call(fun: Term, args: Seq[RuntimeEvaluationTree])

private def standardize(apply: Stat): Call =
private def standardize(apply: Stat): Validation[Call] =
apply match {
case apply: Term.Apply => Call(apply.fun, apply.argClause)
case apply: Term.Apply => apply.argClause.map(validateAsValue).traverse.map(Call(apply.fun, _))
case Term.ApplyInfix.After_4_6_0(lhs, op, _, argClause) if op.value.endsWith(":") =>
Call(Term.Select(argClause.head, op), List(lhs))
case apply: Term.ApplyInfix => Call(Term.Select(apply.lhs, apply.op), apply.argClause)
case apply: Term.ApplyUnary => Call(Term.Select(apply.arg, Term.Name("unary_" + apply.op)), List.empty)
validateAsValue(lhs).map(arg => Call(Term.Select(argClause.head, op), List(arg)))
case apply: Term.ApplyInfix =>
validateAsValue(apply.argClause(0)).map(arg => Call(Term.Select(apply.lhs, apply.op), Seq(arg)))
case apply: Term.ApplyUnary => Valid(Call(Term.Select(apply.arg, Term.Name("unary_" + apply.op)), List.empty))
}

private def validateMethod(call: Call): Validation[RuntimeEvaluationTree] = {
call.argClause.map(validateAsValue).traverse.flatMap { args =>
call.fun match {
case Term.Select(qualifier, Name(name)) =>
validateAsValueOrClass(qualifier)
.flatMap {
case Left(qualifier) =>
asPrimitiveOp(qualifier, name, args)
.orElse(
findMethodBySignedName(qualifier, name, args),
resetError = isReference(qualifier) || args.size > 2
)
case Right(qualifier) => findStaticMethodBySignedName(qualifier, name, args)
}
.orElse(validateAsValue(call.fun).flatMap(findApplyMethod(_, args)))
case Term.Name(name) =>
thisTree
.flatMap(findMethodInThisOrOuter(_, name, args))
.orElse(validateAsValue(call.fun).flatMap(findApplyMethod(_, args)))
.orElse(declaringType.flatMap(findStaticMethodBySignedName(_, name, args)))
}
call.fun match {
case Term.Select(qualifier, Name(name)) =>
validateAsValueOrClass(qualifier)
.flatMap {
case Left(qualifier) =>
logger.info(s"args: ${call.args.map(_.`type`).mkString(", ")}")
asPrimitiveOp(qualifier, name, call.args)
.orElse(
findMethodBySignedName(qualifier, name, call.args),
resetError = isReference(qualifier) || call.args.size > 2
)
.orElseIf(qualifier.`type`.name() == "String")({logger.info("here"); catchStringOpsMethods(qualifier, name, call.args)})
case Right(qualifier) => findStaticMethodBySignedName(qualifier, name, call.args)
}
.orElse(validateAsValue(call.fun).flatMap(findApplyMethod(_, call.args)))
case Term.Name(name) =>
thisTree
.flatMap(findMethodInThisOrOuter(_, name, call.args))
.orElse(validateAsValue(call.fun).flatMap(findApplyMethod(_, call.args)))
.orElse(declaringType.flatMap(findStaticMethodBySignedName(_, name, call.args)))
}
}

Expand Down Expand Up @@ -351,7 +353,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
args <- argClauses.flatMap(_.map(validateAsValue)).traverse
cls <- findClass(tpe)
allArgs = extractCapture(cls).toSeq ++ args
init <- findMethodBySignedName(cls.`type`, "<init>", allArgs.map(_.`type`))
init <- findBestMethodBySignedName(cls.`type`, "<init>", allArgs.map(_.`type`))
} yield NewInstance(CallStaticMethod(init, allArgs, cls.`type`))
}

Expand Down Expand Up @@ -413,19 +415,11 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
}

for {
lhs <- lhs.flatMap {
case field: Field =>
if (field.isMutable) Valid(field) else Recoverable(s"${field.field.name} is not mutable")
case localVar: LocalVar => Valid(localVar)
case _ => Recoverable(s"${tree.lhs} is neither a variable nor a field")
}
lhs <- lhs
.collect { case a: Assignable if a.isMutable => a }
rhs <- validateAsValue(tree.rhs)
.filter(
rhs => isAssignableFrom(rhs.`type`, lhs.`type`),
rhs => s"Cannot assign ${nameOrNull(rhs.`type`)} to ${nameOrNull(lhs.`type`)}"
)
unit <- unitTree
} yield Assign(lhs, rhs, unit.`type`)
.filter("Cannot assign value")(rhs => isAssignableFrom(rhs.`type`, lhs.`type`))
} yield Assign(lhs, rhs, unitTree.get.`type`)
}

private def isReference(tree: RuntimeEvaluationTree): Boolean =
Expand All @@ -443,8 +437,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
.zip(m2.argumentTypes.asScala)
.forall {
case (t1, t2) if nameOrNull(t1) == nameOrNull(t2) => true
case (_: jdi.PrimitiveType, _) => true
case (_, _: jdi.PrimitiveType) => true
case (_: jdi.PrimitiveType, _) | (_, _: jdi.PrimitiveType) => true
case (r1: jdi.ReferenceType, r2: jdi.ReferenceType) => isAssignableFrom(r1, r2)
}
}
Expand Down Expand Up @@ -479,12 +472,14 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
* @param args the arguments types of the method
* @return the method, wrapped in a [[Validation]]
*/
private def findMethodBySignedName(
private def findBestMethodBySignedName(
ref: jdi.ReferenceType,
name: String,
args: Seq[jdi.Type]
): Validation[jdi.Method] = {
val candidates = findMethodsByName(ref, name)
logger.debug(s"Found methods: ${candidates.mkString("\n -> ", "\n -> ", "\n")}")
logger.debug(s"Provided args type: ${args.map(nameOrNull).mkString("[", ", ", "]")}")
val unboxedCandidates = candidates.filter(matchArguments(_, args, boxing = false))
val boxedCandidates = unboxedCandidates.size match {
case 0 => candidates.filter(matchArguments(_, args, boxing = true))
Expand All @@ -506,18 +501,18 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source

private def findZeroArgMethod(qualifier: RuntimeEvaluationTree, name: String): Validation[CallMethod] =
asReference(qualifier.`type`)
.flatMap(zeroArgMethodByName(_, name))
.flatMap { m =>
if (isModuleCall(m)) Recoverable("Accessing a module from its instanciation method is not allowed")
else asInstanceMethod(m, Seq.empty, qualifier)
.flatMap {
findBestMethodBySignedName(_, name, Seq())
.filterNot("Accessing a module from its instanciation method is not allowed")(isModuleCall)
.flatMap(asInstanceMethod(_, Seq.empty, qualifier))
}
.orElse(catchStringOpsMethods(qualifier, name, Seq.empty))

private def findZeroArgStaticMethod(qualifier: jdi.ReferenceType, name: String): Validation[CallMethod] =
zeroArgMethodByName(qualifier, name)
.flatMap { m =>
if (isModuleCall(m)) Recoverable("Accessing a module from its instanciation method is not allowed")
else asStaticMethod(m, Seq.empty, qualifier)
}
private def findZeroArgStaticMethod(qualifier: jdi.ReferenceType, name: String): Validation[CallStaticMethod] =
findBestMethodBySignedName(qualifier, name, Seq()).flatMap { m =>
if (isModuleCall(m)) Recoverable("Accessing a module from its instanciation method is not allowed")
else asStaticMethod(m, Seq.empty, qualifier)
}

private def isModuleCall(m: jdi.Method): Boolean = {
val rt = m.returnTypeName
Expand All @@ -527,18 +522,6 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
noArgs && isSingleton && isSingletonInstantiation
}

private def zeroArgMethodByName(ref: jdi.ReferenceType, name: String): Validation[jdi.Method] = {
findMethodsByName(ref, name).filter(_.argumentTypeNames.isEmpty) match {
case Seq() => Recoverable(s"Cannot find method $name with no args in ${ref.name}")
case Seq(method) => Valid(method).map(loadClassOnNeed)
case methods =>
methods
.filterNot(_.isBridge())
.validateSingle(s"Cannot find method $name with no args in ${ref.name}")
.map(loadClassOnNeed)
}
}

private def findMethodsByName(ref: jdi.ReferenceType, name: String): Seq[jdi.Method] = {
val encodedName = if (name == "<init>" && ref.isInstanceOf[jdi.ClassType]) name else NameTransformer.encode(name)
ref.methodsByName(encodedName).asScala.toSeq
Expand All @@ -549,20 +532,36 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
name: String,
args: Seq[RuntimeEvaluationTree]
): Validation[RuntimeEvaluationTree] = {
if (!args.isEmpty) {
asReference(qualifier.`type`)
.flatMap(tpe => findMethodBySignedName(tpe, name, args.map(_.`type`)))
.flatMap(asInstanceMethod(_, args, qualifier))
val ownMethod = if (!args.isEmpty) {
for {
qualifierType <- asReference(qualifier.`type`)
method <- findBestMethodBySignedName(qualifierType, name, args.map(_.`type`))
methodTree <- asInstanceMethod(method, args, qualifier)
} yield methodTree
} else findZeroArgMethod(qualifier, name)

ownMethod.orElseIf(qualifier.`type`.name() == "String")(catchStringOpsMethods(qualifier, name, args))
}

private def catchStringOpsMethods(qual: RuntimeEvaluationTree, name: String, args: Seq[RuntimeEvaluationTree]) =
for {
so <- findClass("StringOps")
init <- findBestMethodBySignedName(so.`type`, "<init>", Seq(qual.`type`))
initTree = NewInstance(CallStaticMethod(init, Seq(qual), so.`type`))
_ = logger.debug(
s"Looking for method $name on ${initTree.`type`} with args: ${args.map(_.`type`).mkString(", ")}"
)
stringOpMethod <- findBestMethodBySignedName(initTree.`type`, name, args.map(_.`type`))
_ = logger.debug(s"Found method $stringOpMethod")
} yield CallInstanceMethod(stringOpMethod, args, initTree)

private def findStaticMethodBySignedName(
qualifier: jdi.ReferenceType,
name: String,
args: Seq[RuntimeEvaluationTree]
): Validation[CallMethod] =
): Validation[CallStaticMethod] =
if (!args.isEmpty)
findMethodBySignedName(qualifier, name, args.map(_.`type`)).flatMap(asStaticMethod(_, args, qualifier))
findBestMethodBySignedName(qualifier, name, args.map(_.`type`)).flatMap(asStaticMethod(_, args, qualifier))
else findZeroArgStaticMethod(qualifier, name)

private def extractCapture(cls: RuntimeClass): Option[RuntimeEvaluationTree] =
Expand All @@ -575,7 +574,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
.forall { init =>
init.argumentTypeNames.asScala.headOption
.exists { argType =>
val suffix = argType.stripSuffix("$") + "$"
val suffix = toModuleName(argType)
tpe.name.startsWith(suffix) && tpe.name.size > suffix.size
}
}
Expand Down Expand Up @@ -722,7 +721,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
else {
val objectName = NameTransformer.scalaClassName(tpe.name).stripSuffix("$")
asReference(qualifier.`type`)
.flatMap(zeroArgMethodByName(_, objectName))
.flatMap(findBestMethodBySignedName(_, objectName, Seq()))
.map(m => preEvaluate(NestedModule(tpe, CallInstanceMethod(m, Seq.empty, qualifier))))
}

Expand Down Expand Up @@ -756,7 +755,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
method: jdi.Method,
args: Seq[RuntimeEvaluationTree],
qualifier: jdi.ReferenceType
): Validation[CallMethod] =
): Validation[CallStaticMethod] =
if (method.isStatic) Valid(CallStaticMethod(method, args, qualifier))
else Recoverable(s"Cannot access instance method ${method.name} from static context")

Expand Down
Loading
Loading