From 2279c54229fdcf1fee37172f8c90b4fe0945d2f4 Mon Sep 17 00:00:00 2001 From: Joeri Exelmans Date: Mon, 26 May 2025 14:59:17 +0200 Subject: [PATCH] Type inference for recursive let...in... working. Update factorial example. --- src/component/app/App.tsx | 8 ++--- src/component/app/configurations.ts | 2 +- src/component/expr/LetInBlock.tsx | 2 -- src/eval/infer_type.ts | 50 +++++++++++++++++------------ src/util/memoize.ts | 27 ++++++++++++++++ 5 files changed, 62 insertions(+), 27 deletions(-) create mode 100644 src/util/memoize.ts diff --git a/src/component/app/App.tsx b/src/component/app/App.tsx index 4572456..b91fd19 100644 --- a/src/component/app/App.tsx +++ b/src/component/app/App.tsx @@ -162,10 +162,10 @@ export function App() { FACTORY RESET diff --git a/src/component/app/configurations.ts b/src/component/app/configurations.ts index 1878b6e..32befcb 100644 --- a/src/component/app/configurations.ts +++ b/src/component/app/configurations.ts @@ -141,6 +141,6 @@ export const inc: ExprBlockState = {"kind":"let","focus":false,"inner":{"kind":" export const emptySet: ExprBlockState = {"kind":"call","fn":{"kind":"input","text":"set.emptySet","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"","value":{"kind":"gibberish"},"focus":true}}; -export const factorial: ExprBlockState = {"kind":"lambda","paramName":"factorial","focus":true,"expr":{"kind":"lambda","paramName":"n","focus":true,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"leqZero","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"input","text":"1","value":{"kind":"literal","type":"Int"},"focus":false}}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"mulInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":true}},"input":{"kind":"call","fn":{"kind":"input","text":"factorial","value":{"kind":"name"},"focus":true},"input":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"addInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"input","text":"-1","value":{"kind":"literal","type":"Int"},"focus":false}}}}}}}}; +export const factorial: ExprBlockState = {"kind":"let","name":"factorial","focus":true,"value":{"kind":"lambda","paramName":"n","focus":true,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"leqZero","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"input","text":"1","value":{"kind":"literal","type":"Int"},"focus":false}}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"mulInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":true}},"input":{"kind":"call","fn":{"kind":"input","text":"factorial","value":{"kind":"name"},"focus":true},"input":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"addInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"input","text":"-1","value":{"kind":"literal","type":"Int"},"focus":false}}}}}}},"inner":{"kind":"call","fn":{"kind":"input","text":"factorial","value":{"kind":"name"}},"input":{"kind":"input","text":"5","value":{"kind":"literal","type":"Int"}}}}; export const setOfListOfBool: ExprBlockState = {"kind":"call","fn":{"kind":"input","text":"set.emptySet","value":{"kind":"name"},"focus":false},"input":{"kind":"call","fn":{"kind":"input","text":"compareLists","value":{"kind":"name"}},"input":{"kind":"input","text":"compareDoubles","value":{"kind":"name"}}}}; \ No newline at end of file diff --git a/src/component/expr/LetInBlock.tsx b/src/component/expr/LetInBlock.tsx index a866f09..ab7b787 100644 --- a/src/component/expr/LetInBlock.tsx +++ b/src/component/expr/LetInBlock.tsx @@ -33,7 +33,6 @@ export function LetInBlock(props: LetInBlockProps) { } function DeclColumns({state, setState, score, typeInfo}) { - const env = useContext(EnvContext); const globalContext = useContext(GlobalContext); const setInner = callback => setState(state => ({...state, inner: callback(state.inner)})); @@ -78,7 +77,6 @@ function DeclColumns({state, setState, score, typeInfo}) { } function InnerMost({state, setState, score, typeInfo}) { - const env = useContext(EnvContext); const globalContext = useContext(GlobalContext); const setInner = callback => setState(state => ({...state, inner: callback(state.inner)})); const onCancel = () => setState(state => state.value); diff --git a/src/eval/infer_type.ts b/src/eval/infer_type.ts index 1065765..41513ee 100644 --- a/src/eval/infer_type.ts +++ b/src/eval/infer_type.ts @@ -1,10 +1,11 @@ -import { Double, eqType, fnType, IncompatibleTypesError, Int, mergeSubstitutionsN, occurring, prettyS, prettySS, prettyT, recomputeTypeVars, substitute, SubstitutionCycle, trie, TYPE_VARS, UNBOUND_SYMBOLS, unify } from "dope2"; +import { Double, eqType, fnType, IncompatibleTypesError, Int, mergeSubstitutionsN, occurring, recomputeTypeVars, substitute, SubstitutionCycle, trie, TYPE_VARS, UNBOUND_SYMBOLS, unify } from "dope2"; import type { CallBlockState } from "../component/expr/CallBlock"; import type { ExprBlockState } from "../component/expr/ExprBlock"; import type { InputBlockState } from "../component/expr/InputBlock"; import type { LambdaBlockState } from "../component/expr/LambdaBlock"; import type { LetInBlockState } from "../component/expr/LetInBlock"; +import { memoize } from "../util/memoize"; export interface Environment { names: any; @@ -48,7 +49,7 @@ export interface TypeInfoLambda extends TypeInfoCommon { export type TypeInfo = TypeInfoInput | TypeInfoCall | TypeInfoLet | TypeInfoLambda; -export function inferType(s: ExprBlockState, env: Environment): TypeInfo { +export const inferType = memoize(function inferType(s: ExprBlockState, env: Environment): TypeInfo { if (s.kind === "input") { return inferTypeInput(s, env); } @@ -61,9 +62,9 @@ export function inferType(s: ExprBlockState, env: Environment): TypeInfo { else { // (s.kind === "lambda") return inferTypeLambda(s, env); } -} +}); -export function inferTypeInput(s: InputBlockState, env: Environment): TypeInfoInput { +export const inferTypeInput = memoize(function inferTypeInput(s: InputBlockState, env: Environment): TypeInfoInput { if (s.value.kind === "literal") { const type = { Int: Int, @@ -99,9 +100,9 @@ export function inferTypeInput(s: InputBlockState, env: Environment): TypeInfoIn newEnv, err: new Error("Gibberish"), } -} +}); -export function inferTypeCall(s: CallBlockState, env: Environment): TypeInfoCall { +export const inferTypeCall = memoize(function inferTypeCall(s: CallBlockState, env: Environment): TypeInfoCall { const fnTypeInfo = inferType(s.fn, env); const inputEnv = fnTypeInfo.newEnv; const inputTypeInfo = inferType(s.input, inputEnv); @@ -159,46 +160,55 @@ export function inferTypeCall(s: CallBlockState, env: Environment): TypeInfoCall } throw e; } -} +}); -export function inferTypeLet(s: LetInBlockState, env: Environment): TypeInfoLet { - const valTypeInfo = inferType(s.value, env); +export const inferTypeLet = memoize(function inferTypeLet(s: LetInBlockState, env: Environment): TypeInfoLet { + const recursiveTypeInfo = iterateRecursiveType(s.name, s.value, env); + // to eval the 'inner' expr, we only need to add our parameter to the environment: const innerEnv = { - names: trie.insert(env.names)(s.name)({kind: "value", t: valTypeInfo.type}), + names: trie.insert(env.names)(s.name)({kind: "value", t: recursiveTypeInfo.paramType}), typevars: env.typevars, }; const innerTypeInfo = inferType(s.inner, innerEnv); return { kind: "let", type: innerTypeInfo.type, + value: recursiveTypeInfo.inner, subs: innerTypeInfo.subs, - newEnv: env, - value: valTypeInfo, + newEnv: innerTypeInfo.newEnv, inner: innerTypeInfo, innerEnv, }; -} +}); +export const inferTypeLambda = memoize(function inferTypeLambda(s: LambdaBlockState, env: Environment): TypeInfoLambda { + const recursiveTypeInfo = iterateRecursiveType(s.paramName, s.expr, env); + return { + kind: "lambda", + type: fnType(_ => recursiveTypeInfo.paramType)(_ => recursiveTypeInfo.inner.type), + ...recursiveTypeInfo, + }; +}); -export function inferTypeLambda(s: LambdaBlockState, env: Environment): TypeInfoLambda { +// Given a named value whose type we know nothing about, and an expression that computes the value (which may recursively contain the value), compute the type of the value. +// Why? Both lambda functions and let-expressions can refer to themselves recursively. To infer their type, we need to recompute the type and feed it back to itself until some fixed point is reached. +function iterateRecursiveType(paramName: string, expr: ExprBlockState, env: Environment) { let [paramType] = typeUnknown(env); const paramTypeVar = paramType.symbol; let iterations = 1; while (true) { const innerEnv = { - names: trie.insert(env.names)(s.paramName)({kind: "unknown", t: paramType}), + names: trie.insert(env.names)(paramName)({kind: "unknown", t: paramType}), typevars: env.typevars.union(occurring(paramType) as Set), }; - const innerTypeInfo = inferType(s.expr, innerEnv); + const innerTypeInfo = inferType(expr, innerEnv); const subsWithoutPType = new Map(innerTypeInfo.subs); subsWithoutPType.delete(paramTypeVar); const inferredPType = substitute(paramType, innerTypeInfo.subs, []); const [inferredPType2, newEnv] = rewriteInferredType(inferredPType, env); if (eqType(inferredPType2)(paramType)) { return { - kind: "lambda", - type: fnType(_ => paramType)(_ => innerTypeInfo.type), subs: subsWithoutPType, newEnv, paramType, @@ -206,8 +216,8 @@ export function inferTypeLambda(s: LambdaBlockState, env: Environment): TypeInfo innerEnv, }; } - if ((iterations++) == 10) { - throw new Error("too many iterations!"); + if ((iterations++) == 100) { + throw new Error("too many iterations! something's wrong!"); } // console.log("-----------------", iterations); // console.log("paramType:", prettyT(paramType)); diff --git a/src/util/memoize.ts b/src/util/memoize.ts new file mode 100644 index 0000000..8eb2549 --- /dev/null +++ b/src/util/memoize.ts @@ -0,0 +1,27 @@ +// export function memoize(originalFunction: (...args: any[]) => R): (...args: any[]) => R { +// let cache; +// const cacheM = new Map(); +// return function(...args: any[]): any { +// // console.log('memoized', originalFunction.name, args); +// if (args.length === 0) { +// if (cache === undefined) { +// cache = originalFunction(); +// } +// return cache; +// } +// else { +// let result = cacheM.get(args[0]); +// if (result === undefined) { +// result = memoize((...rest) => originalFunction(args[0], ...rest)); +// cacheM.set(args[0], result); +// } +// return result(...args.slice(1)); +// } +// } +// } + +// memoization, implemented as above, only seems to make things slower :( +export function memoize(originalFunction: (...args: any[]) => R): (...args: any[]) => R { + return originalFunction; +} +