feat(script): use cache for compiled scripts

This commit is contained in:
Skylot 2023-10-27 20:49:11 +01:00
parent 192a8116f1
commit 41d986bdca
No known key found for this signature in database
GPG Key ID: 47866607B16F25C8
4 changed files with 132 additions and 4 deletions

View File

@ -12,5 +12,8 @@ dependencies {
implementation("io.github.oshai:kotlin-logging-jvm:5.1.0")
// path for scripts cache
implementation("dev.dirs:directories:26")
testImplementation(project(":jadx-core"))
}

View File

@ -0,0 +1,103 @@
package jadx.plugins.script
import dev.dirs.ProjectDirectories
import java.io.File
import java.security.MessageDigest
import kotlin.script.experimental.api.CompiledScript
import kotlin.script.experimental.api.ScriptCompilationConfiguration
import kotlin.script.experimental.api.SourceCode
import kotlin.script.experimental.jvm.CompiledJvmScriptsCache
import kotlin.script.experimental.jvm.impl.KJvmCompiledScript
import kotlin.script.experimental.jvmhost.loadScriptFromJar
import kotlin.script.experimental.jvmhost.saveToJar
class ScriptCache {
private val enableCache = System.getProperty("JADX_SCRIPT_CACHE_ENABLE", "true").equals("true", ignoreCase = true)
fun build(): CompiledJvmScriptsCache {
if (!enableCache) {
return CompiledJvmScriptsCache.NoCache
}
return JadxScriptsCache(getCacheDir())
}
/**
* Same as CompiledScriptJarsCache implementation,
* but remove all previous cache versions for the script with the same path and name.
* This should reduce old cache entries count
*/
class JadxScriptsCache(private val baseCacheDir: File) : CompiledJvmScriptsCache {
override fun get(
script: SourceCode,
scriptCompilationConfiguration: ScriptCompilationConfiguration,
): CompiledScript? {
val cacheDir = hashDir(baseCacheDir, script)
val file = hashFile(cacheDir, script, scriptCompilationConfiguration)
if (!file.exists()) {
return null
}
return file.loadScriptFromJar() ?: run {
// invalidate cache if the script cannot be loaded
cacheDir.deleteRecursively()
null
}
}
override fun store(
compiledScript: CompiledScript,
script: SourceCode,
scriptCompilationConfiguration: ScriptCompilationConfiguration,
) {
val jvmScript = (compiledScript as? KJvmCompiledScript)
?: throw IllegalArgumentException("Unsupported script type ${compiledScript::class.java.name}")
val cacheDir = hashDir(baseCacheDir, script)
val file = hashFile(cacheDir, script, scriptCompilationConfiguration)
cacheDir.deleteRecursively()
cacheDir.mkdirs()
jvmScript.saveToJar(file)
}
}
private fun getCacheDir(): File {
val dirs = ProjectDirectories.from("io.github", "skylot", "jadx")
val cacheBaseDir = File(dirs.cacheDir, "scripts")
cacheBaseDir.mkdirs()
return cacheBaseDir
}
companion object {
private fun hashDir(baseCacheDir: File, script: SourceCode): File {
if (script.name == null && script.locationId == null) {
return File(baseCacheDir, "tmp")
}
val digest = MessageDigest.getInstance("MD5")
digest.add(script.name)
digest.add(script.locationId)
return File(baseCacheDir, digest.digest().toHexString())
}
private fun hashFile(
cacheDir: File,
script: SourceCode,
scriptCompilationConfiguration: ScriptCompilationConfiguration,
): File {
val digest = MessageDigest.getInstance("MD5")
digest.add(script.text)
scriptCompilationConfiguration.notTransientData.entries
.sortedBy { it.key.name }
.forEach {
digest.add(it.key.name)
digest.add(it.value.toString())
}
return File(cacheDir, digest.digest().toHexString() + ".jar")
}
private fun MessageDigest.add(str: String?) {
str?.let { this.update(it.toByteArray()) }
}
private fun ByteArray.toHexString(): String = joinToString("", transform = { "%02x".format(it) })
}
}

View File

@ -12,7 +12,10 @@ import kotlin.script.experimental.api.ScriptDiagnostic.Severity
import kotlin.script.experimental.api.ScriptEvaluationConfiguration
import kotlin.script.experimental.api.SourceCode
import kotlin.script.experimental.api.constructorArgs
import kotlin.script.experimental.host.ScriptingHostConfiguration
import kotlin.script.experimental.host.toScriptSource
import kotlin.script.experimental.jvm.compilationCache
import kotlin.script.experimental.jvm.jvm
import kotlin.script.experimental.jvmhost.BasicJvmScriptingHost
import kotlin.script.experimental.jvmhost.createJvmCompilationConfigurationFromTemplate
import kotlin.script.experimental.jvmhost.createJvmEvaluationConfigurationFromTemplate
@ -23,17 +26,22 @@ import kotlin.time.toDuration
class ScriptEval {
companion object {
val scriptingHost = BasicJvmScriptingHost()
val scriptingHost = BasicJvmScriptingHost(
baseHostConfiguration = ScriptingHostConfiguration {
jvm {
compilationCache(ScriptCache().build())
}
},
)
val compileConf = createJvmCompilationConfigurationFromTemplate<JadxScriptTemplate>()
private val baseEvalConf = createJvmEvaluationConfigurationFromTemplate<JadxScriptTemplate>()
private fun buildEvalConf(scriptData: JadxScriptData): ScriptEvaluationConfiguration {
return ScriptEvaluationConfiguration(baseEvalConf) {
private fun buildEvalConf(scriptData: JadxScriptData) =
ScriptEvaluationConfiguration(baseEvalConf) {
constructorArgs(scriptData)
}
}
}
fun process(init: JadxPluginContext, scriptOptions: JadxScriptAllOptions): List<JadxScriptData> {

View File

@ -3,14 +3,28 @@ package jadx.plugins.script
import jadx.api.JadxArgs
import jadx.api.JadxDecompiler
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import java.io.File
import kotlin.system.measureTimeMillis
import kotlin.time.DurationUnit
import kotlin.time.toDuration
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class JadxScriptPluginTest {
@BeforeAll
fun disableCache() {
System.setProperty("JADX_SCRIPT_CACHE_ENABLE", "false")
}
@AfterAll
fun clear() {
System.clearProperty("JADX_SCRIPT_CACHE_ENABLE")
}
@Test
fun integrationTest() {
val args = JadxArgs()