Type inference for recursive let...in... working. Update factorial example.

This commit is contained in:
Joeri Exelmans 2025-05-26 14:59:17 +02:00
parent d000839878
commit 2279c54229
5 changed files with 62 additions and 27 deletions

View file

@ -141,6 +141,6 @@ export const inc: ExprBlockState = {"kind":"let","focus":false,"inner":{"kind":"
export const emptySet: ExprBlockState = {"kind":"call","fn":{"kind":"input","text":"set.emptySet","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"","value":{"kind":"gibberish"},"focus":true}}; export const emptySet: ExprBlockState = {"kind":"call","fn":{"kind":"input","text":"set.emptySet","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"","value":{"kind":"gibberish"},"focus":true}};
export const factorial: ExprBlockState = {"kind":"lambda","paramName":"factorial","focus":true,"expr":{"kind":"lambda","paramName":"n","focus":true,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"leqZero","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"input","text":"1","value":{"kind":"literal","type":"Int"},"focus":false}}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"mulInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":true}},"input":{"kind":"call","fn":{"kind":"input","text":"factorial","value":{"kind":"name"},"focus":true},"input":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"addInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"input","text":"-1","value":{"kind":"literal","type":"Int"},"focus":false}}}}}}}}; export const factorial: ExprBlockState = {"kind":"let","name":"factorial","focus":true,"value":{"kind":"lambda","paramName":"n","focus":true,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"leqZero","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"input","text":"1","value":{"kind":"literal","type":"Int"},"focus":false}}},"input":{"kind":"lambda","paramName":"_","focus":false,"expr":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"mulInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":true}},"input":{"kind":"call","fn":{"kind":"input","text":"factorial","value":{"kind":"name"},"focus":true},"input":{"kind":"call","fn":{"kind":"call","fn":{"kind":"input","text":"addInt","value":{"kind":"name"},"focus":false},"input":{"kind":"input","text":"n","value":{"kind":"name"},"focus":false}},"input":{"kind":"input","text":"-1","value":{"kind":"literal","type":"Int"},"focus":false}}}}}}},"inner":{"kind":"call","fn":{"kind":"input","text":"factorial","value":{"kind":"name"}},"input":{"kind":"input","text":"5","value":{"kind":"literal","type":"Int"}}}};
export const setOfListOfBool: ExprBlockState = {"kind":"call","fn":{"kind":"input","text":"set.emptySet","value":{"kind":"name"},"focus":false},"input":{"kind":"call","fn":{"kind":"input","text":"compareLists","value":{"kind":"name"}},"input":{"kind":"input","text":"compareDoubles","value":{"kind":"name"}}}}; export const setOfListOfBool: ExprBlockState = {"kind":"call","fn":{"kind":"input","text":"set.emptySet","value":{"kind":"name"},"focus":false},"input":{"kind":"call","fn":{"kind":"input","text":"compareLists","value":{"kind":"name"}},"input":{"kind":"input","text":"compareDoubles","value":{"kind":"name"}}}};

View file

@ -33,7 +33,6 @@ export function LetInBlock(props: LetInBlockProps) {
} }
function DeclColumns({state, setState, score, typeInfo}) { function DeclColumns({state, setState, score, typeInfo}) {
const env = useContext(EnvContext);
const globalContext = useContext(GlobalContext); const globalContext = useContext(GlobalContext);
const setInner = callback => setState(state => ({...state, inner: callback(state.inner)})); const setInner = callback => setState(state => ({...state, inner: callback(state.inner)}));
@ -78,7 +77,6 @@ function DeclColumns({state, setState, score, typeInfo}) {
} }
function InnerMost({state, setState, score, typeInfo}) { function InnerMost({state, setState, score, typeInfo}) {
const env = useContext(EnvContext);
const globalContext = useContext(GlobalContext); const globalContext = useContext(GlobalContext);
const setInner = callback => setState(state => ({...state, inner: callback(state.inner)})); const setInner = callback => setState(state => ({...state, inner: callback(state.inner)}));
const onCancel = () => setState(state => state.value); const onCancel = () => setState(state => state.value);

View file

@ -1,10 +1,11 @@
import { Double, eqType, fnType, IncompatibleTypesError, Int, mergeSubstitutionsN, occurring, prettyS, prettySS, prettyT, recomputeTypeVars, substitute, SubstitutionCycle, trie, TYPE_VARS, UNBOUND_SYMBOLS, unify } from "dope2"; import { Double, eqType, fnType, IncompatibleTypesError, Int, mergeSubstitutionsN, occurring, recomputeTypeVars, substitute, SubstitutionCycle, trie, TYPE_VARS, UNBOUND_SYMBOLS, unify } from "dope2";
import type { CallBlockState } from "../component/expr/CallBlock"; import type { CallBlockState } from "../component/expr/CallBlock";
import type { ExprBlockState } from "../component/expr/ExprBlock"; import type { ExprBlockState } from "../component/expr/ExprBlock";
import type { InputBlockState } from "../component/expr/InputBlock"; import type { InputBlockState } from "../component/expr/InputBlock";
import type { LambdaBlockState } from "../component/expr/LambdaBlock"; import type { LambdaBlockState } from "../component/expr/LambdaBlock";
import type { LetInBlockState } from "../component/expr/LetInBlock"; import type { LetInBlockState } from "../component/expr/LetInBlock";
import { memoize } from "../util/memoize";
export interface Environment { export interface Environment {
names: any; names: any;
@ -48,7 +49,7 @@ export interface TypeInfoLambda extends TypeInfoCommon {
export type TypeInfo = TypeInfoInput | TypeInfoCall | TypeInfoLet | TypeInfoLambda; export type TypeInfo = TypeInfoInput | TypeInfoCall | TypeInfoLet | TypeInfoLambda;
export function inferType(s: ExprBlockState, env: Environment): TypeInfo { export const inferType = memoize(function inferType(s: ExprBlockState, env: Environment): TypeInfo {
if (s.kind === "input") { if (s.kind === "input") {
return inferTypeInput(s, env); return inferTypeInput(s, env);
} }
@ -61,9 +62,9 @@ export function inferType(s: ExprBlockState, env: Environment): TypeInfo {
else { // (s.kind === "lambda") else { // (s.kind === "lambda")
return inferTypeLambda(s, env); return inferTypeLambda(s, env);
} }
} });
export function inferTypeInput(s: InputBlockState, env: Environment): TypeInfoInput { export const inferTypeInput = memoize(function inferTypeInput(s: InputBlockState, env: Environment): TypeInfoInput {
if (s.value.kind === "literal") { if (s.value.kind === "literal") {
const type = { const type = {
Int: Int, Int: Int,
@ -99,9 +100,9 @@ export function inferTypeInput(s: InputBlockState, env: Environment): TypeInfoIn
newEnv, newEnv,
err: new Error("Gibberish"), err: new Error("Gibberish"),
} }
} });
export function inferTypeCall(s: CallBlockState, env: Environment): TypeInfoCall { export const inferTypeCall = memoize(function inferTypeCall(s: CallBlockState, env: Environment): TypeInfoCall {
const fnTypeInfo = inferType(s.fn, env); const fnTypeInfo = inferType(s.fn, env);
const inputEnv = fnTypeInfo.newEnv; const inputEnv = fnTypeInfo.newEnv;
const inputTypeInfo = inferType(s.input, inputEnv); const inputTypeInfo = inferType(s.input, inputEnv);
@ -159,46 +160,55 @@ export function inferTypeCall(s: CallBlockState, env: Environment): TypeInfoCall
} }
throw e; throw e;
} }
} });
export function inferTypeLet(s: LetInBlockState, env: Environment): TypeInfoLet { export const inferTypeLet = memoize(function inferTypeLet(s: LetInBlockState, env: Environment): TypeInfoLet {
const valTypeInfo = inferType(s.value, env); const recursiveTypeInfo = iterateRecursiveType(s.name, s.value, env);
// to eval the 'inner' expr, we only need to add our parameter to the environment:
const innerEnv = { const innerEnv = {
names: trie.insert(env.names)(s.name)({kind: "value", t: valTypeInfo.type}), names: trie.insert(env.names)(s.name)({kind: "value", t: recursiveTypeInfo.paramType}),
typevars: env.typevars, typevars: env.typevars,
}; };
const innerTypeInfo = inferType(s.inner, innerEnv); const innerTypeInfo = inferType(s.inner, innerEnv);
return { return {
kind: "let", kind: "let",
type: innerTypeInfo.type, type: innerTypeInfo.type,
value: recursiveTypeInfo.inner,
subs: innerTypeInfo.subs, subs: innerTypeInfo.subs,
newEnv: env, newEnv: innerTypeInfo.newEnv,
value: valTypeInfo,
inner: innerTypeInfo, inner: innerTypeInfo,
innerEnv, innerEnv,
}; };
} });
export const inferTypeLambda = memoize(function inferTypeLambda(s: LambdaBlockState, env: Environment): TypeInfoLambda {
const recursiveTypeInfo = iterateRecursiveType(s.paramName, s.expr, env);
return {
kind: "lambda",
type: fnType(_ => recursiveTypeInfo.paramType)(_ => recursiveTypeInfo.inner.type),
...recursiveTypeInfo,
};
});
export function inferTypeLambda(s: LambdaBlockState, env: Environment): TypeInfoLambda { // Given a named value whose type we know nothing about, and an expression that computes the value (which may recursively contain the value), compute the type of the value.
// Why? Both lambda functions and let-expressions can refer to themselves recursively. To infer their type, we need to recompute the type and feed it back to itself until some fixed point is reached.
function iterateRecursiveType(paramName: string, expr: ExprBlockState, env: Environment) {
let [paramType] = typeUnknown(env); let [paramType] = typeUnknown(env);
const paramTypeVar = paramType.symbol; const paramTypeVar = paramType.symbol;
let iterations = 1; let iterations = 1;
while (true) { while (true) {
const innerEnv = { const innerEnv = {
names: trie.insert(env.names)(s.paramName)({kind: "unknown", t: paramType}), names: trie.insert(env.names)(paramName)({kind: "unknown", t: paramType}),
typevars: env.typevars.union(occurring(paramType) as Set<string>), typevars: env.typevars.union(occurring(paramType) as Set<string>),
}; };
const innerTypeInfo = inferType(s.expr, innerEnv); const innerTypeInfo = inferType(expr, innerEnv);
const subsWithoutPType = new Map(innerTypeInfo.subs); const subsWithoutPType = new Map(innerTypeInfo.subs);
subsWithoutPType.delete(paramTypeVar); subsWithoutPType.delete(paramTypeVar);
const inferredPType = substitute(paramType, innerTypeInfo.subs, []); const inferredPType = substitute(paramType, innerTypeInfo.subs, []);
const [inferredPType2, newEnv] = rewriteInferredType(inferredPType, env); const [inferredPType2, newEnv] = rewriteInferredType(inferredPType, env);
if (eqType(inferredPType2)(paramType)) { if (eqType(inferredPType2)(paramType)) {
return { return {
kind: "lambda",
type: fnType(_ => paramType)(_ => innerTypeInfo.type),
subs: subsWithoutPType, subs: subsWithoutPType,
newEnv, newEnv,
paramType, paramType,
@ -206,8 +216,8 @@ export function inferTypeLambda(s: LambdaBlockState, env: Environment): TypeInfo
innerEnv, innerEnv,
}; };
} }
if ((iterations++) == 10) { if ((iterations++) == 100) {
throw new Error("too many iterations!"); throw new Error("too many iterations! something's wrong!");
} }
// console.log("-----------------", iterations); // console.log("-----------------", iterations);
// console.log("paramType:", prettyT(paramType)); // console.log("paramType:", prettyT(paramType));

27
src/util/memoize.ts Normal file
View file

@ -0,0 +1,27 @@
// export function memoize<R>(originalFunction: (...args: any[]) => R): (...args: any[]) => R {
// let cache;
// const cacheM = new Map();
// return function(...args: any[]): any {
// // console.log('memoized', originalFunction.name, args);
// if (args.length === 0) {
// if (cache === undefined) {
// cache = originalFunction();
// }
// return cache;
// }
// else {
// let result = cacheM.get(args[0]);
// if (result === undefined) {
// result = memoize((...rest) => originalFunction(args[0], ...rest));
// cacheM.set(args[0], result);
// }
// return result(...args.slice(1));
// }
// }
// }
// memoization, implemented as above, only seems to make things slower :(
export function memoize<R>(originalFunction: (...args: any[]) => R): (...args: any[]) => R {
return originalFunction;
}