diff --git a/lib/generics/generics.js b/lib/generics/generics.js index 343c765..10733c2 100644 --- a/lib/generics/generics.js +++ b/lib/generics/generics.js @@ -68,10 +68,12 @@ export const mergeTwoWay = (m1, m2) => { checkConflict(m1, m2); // checkConflict(m2, m1); // <- don't think this is necessary... // actually merge - let stable = false; - while (!stable) { + let remaining = 2; + while (remaining > 0) { // notice we swap m2 and m1, so the rewriting can happen both ways: + let stable; [stable, m2, m1] = mergeOneWay(m1, m2); + remaining -= stable; } const result = new Map([...m1, ...m2]); // console.log("mergeTwoWay result =", result); @@ -187,8 +189,7 @@ export const substitute = (type, substitutions, stack=[]) => { }; export const assignFn = (funType, paramType) => { - const [outType] = assignFnSubstitutions(funType, paramType); - return outType; + return assignFnSubstitutions(funType, paramType)[2]; }; // same as above, but also returns the substitutions that took place @@ -197,12 +198,13 @@ export const assignFnSubstitutions = (funType, paramType) => { throw new NotAFunctionError(`${prettyT(funType)} is not a function type!`); } const [[refunType, funS], [reparamType, paramS]] = recomputeTypeVarsSubstitutions([funType, paramType]); + const recomputationSubstitutions = mergeTwoWay(funS, paramS); const [inType, outType] = refunType.params.map(p => p(refunType)); - const {substitutions} = __unify(inType, reparamType); - // console.log(substitutions, prettyT(outType)); + const {type: newInType, substitutions} = __unify(inType, reparamType); + const inTypeSubst = mergeTwoWay(substitutions, recomputationSubstitutions); const substitutedFnType = substitute(outType, substitutions); - const computedOutType = recomputeTypeVars([substitutedFnType])[0]; - return [computedOutType, substitutions]; + const [computedOutType, outSubst] = recomputeTypeVarsSubstitutions([substitutedFnType])[0]; + return [newInType, inTypeSubst, computedOutType, outSubst]; }; // Ensures that no type variables overlap diff --git a/tests/generics.js b/tests/generics.js index 97fa238..72c040d 100644 --- a/tests/generics.js +++ b/tests/generics.js @@ -1,7 +1,8 @@ import assert from "node:assert"; -import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; +import { assignFn, assignFnSubstitutions, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; import { getDefaultTypeParser } from "../lib/parser/type_parser.js"; import { prettyT } from "../lib/util/pretty.js"; +import { TYPE_VARS, UNBOUND_SYMBOLS } from "../lib/primitives/typevars.js"; const mkType = getDefaultTypeParser(); @@ -65,3 +66,16 @@ assert.throws( // expected error UnifyError, ); + +const [inType, inSubst, outType, outSubst] = assignFnSubstitutions( + mkType("Int -> Int"), + mkType("b"), +); + +assert.equal(prettyT(inType), "Int"); +assert.equal(prettyT(outType), "Int"); +assert.equal(inSubst.size, 1); +assert.equal(prettyT( + inSubst.get(UNBOUND_SYMBOLS[1]) // b +), "Int") +assert.equal(outSubst.size, 0); \ No newline at end of file