mirror of
https://github.com/langchain-ai/kork.git
synced 2026-07-01 22:24:02 -04:00
393 lines
10 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62549efe",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Retrievers\n",
    "\n",
    "The `Kork` `CodeChain` accepts retrievers for both examples and foreign functions.\n",
    "\n",
    "The built-in retrievers return the content they were initialized with, unchanged.\n",
    "\n",
    "You could instead implement dynamic retrievers that rely on a vector store and return only the examples and foreign functions most relevant to the task at hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "79fd8e41",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "nbsphinx": "hidden",
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": [
     "remove_cell"
    ]
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.insert(0, \"../\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fb100da-29d9-442b-b89c-4b60bbc1967b",
   "metadata": {},
   "source": [
    "The code below defines a placeholder chat model that always returns a fixed response.\n",
    "Feel free to replace it with a real language model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c5e0604-b50c-470d-855c-9b5253206a81",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "from typing import Any, List, Optional\n",
    "\n",
    "from langchain.chat_models.base import BaseChatModel\n",
    "from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult\n",
    "from pydantic import Extra\n",
    "\n",
    "\n",
    "class ToyChatModel(BaseChatModel):\n",
    "    response: str\n",
    "\n",
    "    class Config:\n",
    "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
    "\n",
    "        extra = Extra.forbid\n",
    "        arbitrary_types_allowed = True\n",
    "\n",
    "    def _generate(\n",
    "        self, messages: List[BaseMessage], stop: Optional[List[str]] = None\n",
    "    ) -> ChatResult:\n",
    "        message = AIMessage(content=self.response)\n",
    "        generation = ChatGeneration(message=message)\n",
    "        return ChatResult(generations=[generation])\n",
    "\n",
    "    async def _agenerate(\n",
    "        self, messages: List[BaseMessage], stop: Optional[List[str]] = None\n",
    "    ) -> Any:\n",
    "        \"\"\"Async version of _generate.\"\"\"\n",
    "        message = AIMessage(content=self.response)\n",
    "        generation = ChatGeneration(message=message)\n",
    "        return ChatResult(generations=[generation])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77a7b08f-9d98-475c-b10e-711cbb9d1fbe",
   "metadata": {},
   "source": [
    "## Context Retriever\n",
    "\n",
    "A context retriever (yes, it's poorly named) exposes a `retrieve` method that takes the user `query` and returns a list of foreign functions.\n",
    "\n",
    "The built-in `SimpleContextRetriever` returns the functions it was initialized with and ignores the `query` entirely.\n",
    "\n",
    "Subclass the abstract interface to provide your own implementation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "32df6b1b-edeb-4117-be67-c225db5c7a28",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from kork import SimpleContextRetriever\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a3db1e43-5445-40c7-b11c-6e27c5c3dfca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "simple_context_retriever = SimpleContextRetriever.from_functions(\n",
    "    [math.pow, math.log2, math.log10]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bcd5a5c6-2b4e-4ef9-842f-992bd8a823a4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[ExternFunctionDef(name='pow', params=ParamList(params=[Param(name='x', type_='Any'), Param(name='y', type_='Any')]), return_type='Any', implementation=<built-in function pow>, doc_string='Return x**y (x to the power of y).'),\n",
       " ExternFunctionDef(name='log2', params=ParamList(params=[Param(name='x', type_='Any')]), return_type='Any', implementation=<built-in function log2>, doc_string='Return the base 2 logarithm of x.'),\n",
       " ExternFunctionDef(name='log10', params=ParamList(params=[Param(name='x', type_='Any')]), return_type='Any', implementation=<built-in function log10>, doc_string='Return the base 10 logarithm of x.')]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simple_context_retriever.retrieve(\"this is the query\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9a05cc8-579f-4bf7-a8b9-61fe13b336ed",
   "metadata": {},
   "source": [
    "## Example Retriever\n",
    "\n",
    "An example retriever supports the same `retrieve` method as a context retriever, but returns example programs instead of foreign functions.\n",
    "\n",
    "The built-in `SimpleExampleRetriever` returns the examples it was initialized with and ignores the `query` entirely.\n",
    "\n",
    "Subclass the abstract interface to provide your own implementation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "54d8822b-31f1-4e35-9399-2fb8e09ea91b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from kork import c_, r_, AstPrinter, SimpleExampleRetriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "73a91b95-25a0-46c3-be2f-3810152971d6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "language_name = \"RollingMeow\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e0f8176e-6722-4306-8b6e-8d223f113b0e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "retriever = SimpleExampleRetriever.from_programs(\n",
    "    language_name,\n",
    "    [\n",
    "        (\"2**5\", r_(c_(math.pow, 2, 5))),\n",
    "        (\"take the log base 2 of 2\", r_(c_(math.log2, 2))),\n",
    "    ],\n",
    "    AstPrinter(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d99b3cd7-3ecc-4c8b-b99c-630b2b717b3a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('2**5', '```RollingMeow\\nvar result = pow(2, 5)\\n```'),\n",
       " ('take the log base 2 of 2', '```RollingMeow\\nvar result = log2(2)\\n```')]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever.retrieve(\"[ignore]\")"
   ]
  },
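  {
   "cell_type": "markdown",
   "id": "e5b1c2a0-dynamic-retriever-note",
   "metadata": {},
   "source": [
    "As a minimal sketch of a query-dependent retriever, the hypothetical `KeywordExampleRetriever` below ranks the stored examples by word overlap with the query and returns the top `k`.\n",
    "\n",
    "The class name, the `k` parameter and the overlap scoring are assumptions made for illustration and are not part of `kork`. The class only mirrors the `retrieve(query)` shape shown above; to pass it to `CodeChain` it would presumably need to subclass the abstract example retriever interface, and a real implementation would more likely query a vector store."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b1c2a0-dynamic-retriever-code",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "class KeywordExampleRetriever:\n",
    "    \"\"\"Hypothetical retriever (not part of kork): ranks examples by word overlap.\"\"\"\n",
    "\n",
    "    def __init__(self, examples: List[Tuple[str, str]], k: int = 1) -> None:\n",
    "        self.examples = examples  # (query, program) pairs\n",
    "        self.k = k  # Number of examples to return\n",
    "\n",
    "    def retrieve(self, query: str) -> List[Tuple[str, str]]:\n",
    "        query_words = set(query.lower().split())\n",
    "\n",
    "        def overlap(example: Tuple[str, str]) -> int:\n",
    "            # Count the words shared by the user query and the example's query.\n",
    "            return len(query_words & set(example[0].lower().split()))\n",
    "\n",
    "        # Highest overlap first; ties keep their original order.\n",
    "        return sorted(self.examples, key=overlap, reverse=True)[: self.k]\n",
    "\n",
    "\n",
    "# Reuse the (query, program) pairs produced by the built-in retriever above.\n",
    "keyword_retriever = KeywordExampleRetriever(retriever.retrieve(\"\"), k=1)\n",
    "keyword_retriever.retrieve(\"log base 2 of 8\")"
   ]
  },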
  {
   "cell_type": "markdown",
   "id": "ab4d3808-50ab-4ade-aab7-677e23624cab",
   "metadata": {},
   "source": [
    "## Declare the chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "26874d63-5460-42c7-a346-5d70691ba02e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from kork import (\n",
    "    CodeChain,\n",
    "    ast,\n",
    "    AstPrinter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5e2df25f-169d-4961-9fd9-4b10bb9f07c9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chain = CodeChain.from_defaults(\n",
    "    llm=ToyChatModel(response=\"MEOW MEOW MEOW MEOW\"),  # The LLM to use\n",
    "    ast_printer=AstPrinter(),  # Knows how to print the AST\n",
    "    examples=retriever,  # Example programs\n",
    "    context=simple_context_retriever,  # Foreign functions\n",
    "    language_name=language_name,  # Language name used in the prompt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7d92f7b6-d1f5-41b5-826b-b338edc3831e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_, few_shot_prompt = chain.prepare_context(query=\"hello\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "06b5f073-bb8c-40a5-a816-52075417dee2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are programming in a language called \"RollingMeow\".\n",
      "\n",
      "You are an expert programmer and must follow the instructions below exactly.\n",
      "\n",
      "Your goal is to translate a user query into a corresponding and valid RollingMeow\n",
      "program.\n",
      "\n",
      "You have access to the following external functions:\n",
      "\n",
      "```RollingMeow\n",
      "\n",
      "extern fn pow(x: Any, y: Any) -> Any // Return x**y (x to the power of y).\n",
      "extern fn log2(x: Any) -> Any // Return the base 2 logarithm of x.\n",
      "extern fn log10(x: Any) -> Any // Return the base 10 logarithm of x.\n",
      "```\n",
      "\n",
      "\n",
      "Do not assume that any other functions except for the ones listed above exist.\n",
      "\n",
      "Wrap the program in ```RollingMeow and ``` to indicate the start and end of\n",
      "the program.\n",
      "\n",
      "Store the solution to the query in a variable called \"result\".\n",
      "\n",
      "Here is a sample valid program:\n",
      "\n",
      "```RollingMeow\n",
      "var x = 1 # Assign 1 to the variable x\n",
      "var result = 1 + 2 # Calculate the sum of 1 + 2 and assign to result\n",
      "var result = x # Assign the value of x to result\n",
      "```\n",
      "\n",
      "Guidelines:\n",
      "- Do not use operators, instead invoke appropriate external functions.\n",
      "- Do not declare functions, do not use loops, do not use conditionals.\n",
      "- Solve the problem only using variable declarations and function invocations.\n",
      "\n",
      "Begin!\n",
      "Input: 2**5\n",
      "Output: ```RollingMeow\n",
      "var result = pow(2, 5)\n",
      "```\n",
      "\n",
      "Input: take the log base 2 of 2\n",
      "Output: ```RollingMeow\n",
      "var result = log2(2)\n",
      "```\n",
      "\n",
      "Input: [user input]\n",
      "Output:\n"
     ]
    }
   ],
   "source": [
    "print(few_shot_prompt.format_prompt(query=\"[user input]\").to_string())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}