when assigning parameter to function, the returned typevar substitutions must take into account any typevar recomputations

This commit is contained in:
Joeri Exelmans 2025-05-18 09:58:11 +02:00
parent 8266e59b94
commit 248d8ddef1
2 changed files with 25 additions and 9 deletions

View file

@ -68,10 +68,12 @@ export const mergeTwoWay = (m1, m2) => {
checkConflict(m1, m2); checkConflict(m1, m2);
// checkConflict(m2, m1); // <- don't think this is necessary... // checkConflict(m2, m1); // <- don't think this is necessary...
// actually merge // actually merge
let stable = false; let remaining = 2;
while (!stable) { while (remaining > 0) {
// notice we swap m2 and m1, so the rewriting can happen both ways: // notice we swap m2 and m1, so the rewriting can happen both ways:
let stable;
[stable, m2, m1] = mergeOneWay(m1, m2); [stable, m2, m1] = mergeOneWay(m1, m2);
remaining -= stable;
} }
const result = new Map([...m1, ...m2]); const result = new Map([...m1, ...m2]);
// console.log("mergeTwoWay result =", result); // console.log("mergeTwoWay result =", result);
@ -187,8 +189,7 @@ export const substitute = (type, substitutions, stack=[]) => {
}; };
export const assignFn = (funType, paramType) => { export const assignFn = (funType, paramType) => {
const [outType] = assignFnSubstitutions(funType, paramType); return assignFnSubstitutions(funType, paramType)[2];
return outType;
}; };
// same as above, but also returns the substitutions that took place // same as above, but also returns the substitutions that took place
@ -197,12 +198,13 @@ export const assignFnSubstitutions = (funType, paramType) => {
throw new NotAFunctionError(`${prettyT(funType)} is not a function type!`); throw new NotAFunctionError(`${prettyT(funType)} is not a function type!`);
} }
const [[refunType, funS], [reparamType, paramS]] = recomputeTypeVarsSubstitutions([funType, paramType]); const [[refunType, funS], [reparamType, paramS]] = recomputeTypeVarsSubstitutions([funType, paramType]);
const recomputationSubstitutions = mergeTwoWay(funS, paramS);
const [inType, outType] = refunType.params.map(p => p(refunType)); const [inType, outType] = refunType.params.map(p => p(refunType));
const {substitutions} = __unify(inType, reparamType); const {type: newInType, substitutions} = __unify(inType, reparamType);
// console.log(substitutions, prettyT(outType)); const inTypeSubst = mergeTwoWay(substitutions, recomputationSubstitutions);
const substitutedFnType = substitute(outType, substitutions); const substitutedFnType = substitute(outType, substitutions);
const computedOutType = recomputeTypeVars([substitutedFnType])[0]; const [computedOutType, outSubst] = recomputeTypeVarsSubstitutions([substitutedFnType])[0];
return [computedOutType, substitutions]; return [newInType, inTypeSubst, computedOutType, outSubst];
}; };
// Ensures that no type variables overlap // Ensures that no type variables overlap

View file

@ -1,7 +1,8 @@
import assert from "node:assert"; import assert from "node:assert";
import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; import { assignFn, assignFnSubstitutions, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js";
import { getDefaultTypeParser } from "../lib/parser/type_parser.js"; import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
import { prettyT } from "../lib/util/pretty.js"; import { prettyT } from "../lib/util/pretty.js";
import { TYPE_VARS, UNBOUND_SYMBOLS } from "../lib/primitives/typevars.js";
const mkType = getDefaultTypeParser(); const mkType = getDefaultTypeParser();
@ -65,3 +66,16 @@ assert.throws(
// expected error // expected error
UnifyError, UnifyError,
); );
const [inType, inSubst, outType, outSubst] = assignFnSubstitutions(
mkType("Int -> Int"),
mkType("b"),
);
assert.equal(prettyT(inType), "Int");
assert.equal(prettyT(outType), "Int");
assert.equal(inSubst.size, 1);
assert.equal(prettyT(
inSubst.get(UNBOUND_SYMBOLS[1]) // b
), "Int")
assert.equal(outSubst.size, 0);