Files
kork/tests/test_chain.py
Eugene Yurtsev 6d4dc1b23f Use <code> block, add docs (#3)
* Swap to encoding code block in `<code>` tags which seem more robust.
* Add more documentation.
2023-05-05 10:22:17 -04:00

112 lines
4.1 KiB
Python

from kork import AstPrinter, CodeChain, SimpleContextRetriever, run_interpreter
from kork.examples import SimpleExampleRetriever
from kork.exceptions import KorkRunTimeException, LLMParseException
from kork.parser import parse # type: ignore
from .utils import ToyChatModel
def test_code_chain() -> None:
"""Test the code chain."""
example_retriever = SimpleExampleRetriever(examples=[("Add 1 and 2", "add_(1, 2)")])
llm = ToyChatModel(response="<code>var x = 1</code>")
chain = CodeChain(
llm=llm,
retriever=SimpleContextRetriever(),
interpreter=run_interpreter,
ast_printer=AstPrinter(),
example_retriever=example_retriever,
)
response = chain(inputs={"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
env = response.pop("environment")
assert response == {
"query": "blah",
"code": "var x = 1",
"errors": [],
"raw": "<code>var x = 1</code>",
}
assert env.get_symbol("x") == 1
def test_bad_program() -> None:
"""Test the code chain."""
example_retriever = SimpleExampleRetriever(examples=[("Add 1 and 2", "add_(1, 2)")])
llm = ToyChatModel(response="<code>\nINVALID PROGRAM\n</code>")
chain = CodeChain(
llm=llm,
retriever=SimpleContextRetriever(),
interpreter=run_interpreter,
ast_printer=AstPrinter(),
example_retriever=example_retriever,
)
response = chain(inputs={"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["raw"] == "<code>\nINVALID PROGRAM\n</code>"
assert response["code"] == "\nINVALID PROGRAM\n"
assert response["environment"] is None
assert response["errors"] is not None
assert isinstance(response["errors"][0], KorkRunTimeException)
def test_llm_output_missing_program() -> None:
"""Test the code chain with llm response that's missing code."""
llm = ToyChatModel(response="oops.")
example_retriever = SimpleExampleRetriever(examples=[("Add 1 and 2", "add_(1, 2)")])
chain = CodeChain(
llm=llm,
retriever=SimpleContextRetriever(),
interpreter=run_interpreter,
ast_printer=AstPrinter(),
example_retriever=example_retriever,
)
response = chain(inputs={"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["raw"] == "oops."
assert response["code"] == ""
assert response["environment"] is None
errors = response["errors"]
assert len(errors) == 1
assert isinstance(response["errors"][0], LLMParseException)
def test_from_defaults_instantiation() -> None:
"""Test from default instantiation."""
llm = ToyChatModel(response="<code>\nvar x = 1;\n</code>")
chain = CodeChain.from_defaults(llm=llm)
response = chain(inputs={"query": "blah"})
# Why does the chain return a `query` key?
assert sorted(response) == ["code", "environment", "errors", "query", "raw"]
assert response["environment"].get_symbol("x") == 1
# Instantiate with sequence of examples and foreign functions
def _do() -> int:
"""Do something"""
return 1
# Verify that we can instantiate
examples = [("Add 1 and 2", parse("add_(1, 2)"))]
chain2 = CodeChain.from_defaults(
llm=llm, examples=examples, context=[_do], language_name="kork"
)
assert isinstance(chain2.context_retriever, SimpleContextRetriever)
external_func_defs = chain2.context_retriever.retrieve("[anything]")
# Check foreign funcs
assert len(external_func_defs) == 1
assert external_func_defs[0].name == "_do"
# Check examples
assert isinstance(chain2.example_retriever, SimpleExampleRetriever)
assert chain2.example_retriever.retrieve("[anything]") == [
# Please note that an input formatted was applied by default to the example
("```text\nAdd 1 and 2\n```", "<code>add_(1, 2)</code>")
]