Browse Source

struct argument passing

Kajetan Johannes Hammerle 3 years ago
parent
commit
82ce69b5a8

+ 138 - 53
Compiler.c

@@ -208,47 +208,89 @@ static DataType cCallFunction(const char* name) {
     return found->returnType;
 }
 
-static DataType cLoadRef(Variable* v, Operation op, DataType dt) {
+static void cAddOffset(int offset) {
+    if(offset > 0) {
+        cAddOperation(OP_PUSH_INT);
+        cAddInt(offset);
+        cAddOperation(OP_ADD_INT);
+    }
+}
+
+static DataType cLoadRef(Variable* v, Operation op, DataType dt, int offset) {
     cAddOperation(OP_LOAD_INT);
     cAddInt(v->address);
+    cAddOffset(offset);
     cAddOperation(op);
     return dt;
 }
 
-static DataType cLoadVariable(Variable* v) {
-    if(dtCompare(v->type, dtToReference(dtInt()))) {
-        return cLoadRef(v, OP_REF_LOAD_INT, dtInt());
-    } else if(dtCompare(v->type, dtToReference(dtFloat()))) {
-        return cLoadRef(v, OP_REF_LOAD_FLOAT, dtFloat());
-    } else if(dtCompare(v->type, dtToReference(dtBool()))) {
-        return cLoadRef(v, OP_REF_LOAD_BOOL, dtBool());
+static DataType cAddVariable(Operation op, Variable* v) {
+    cAddOperation(op);
+    cAddInt(v->address);
+    return v->type;
+}
+
+static DataType cLoadVariable(Variable* v, Variable* sv) {
+    DataType type = v->type;
+    int offset = 0;
+    if(dtIsStruct(type) && sv->type.type != DT_VOID) {
+        sv->address += v->address;
+        return cLoadVariable(sv, sv);
+    } else if(dtIsStructRef(type)) {
+        type = dtToReference(sv->type);
+        offset = sv->address;
+    }
+    if(dtCompare(type, dtToReference(dtInt()))) {
+        return cLoadRef(v, OP_REF_LOAD_INT, dtInt(), offset);
+    } else if(dtCompare(type, dtToReference(dtFloat()))) {
+        return cLoadRef(v, OP_REF_LOAD_FLOAT, dtFloat(), offset);
+    } else if(dtCompare(type, dtToReference(dtBool()))) {
+        return cLoadRef(v, OP_REF_LOAD_BOOL, dtBool(), offset);
     }
     switch(dtAsInt(v->type)) {
-        case DT_INT: cAddOperation(OP_LOAD_INT); break;
-        case DT_BOOL: cAddOperation(OP_LOAD_BOOL); break;
-        case DT_FLOAT: cAddOperation(OP_LOAD_FLOAT); break;
-        default: cError("cannot load type %s", cGetName(v->type));
+        case DT_INT: return cAddVariable(OP_LOAD_INT, v);
+        case DT_BOOL: return cAddVariable(OP_LOAD_BOOL, v);
+        case DT_FLOAT: return cAddVariable(OP_LOAD_FLOAT, v);
+        case DT_STRUCT:
+            {
+                Struct* st = dtGetStruct(&structs, v->type);
+                if(st == NULL) {
+                    cError("compiler struct error");
+                }
+                int address = v->address;
+                for(int i = 0; i < st->amount; i++) {
+                    Variable v = {st->vars[i].name, st->vars[i].type, address};
+                    cLoadVariable(&v, sv);
+                    address += dtGetSize(v.type, &structs);
+                }
+                return v->type;
+            }
+        default:
+            cError("cannot load type %s", cGetName(v->type));
+            return dtVoid();
     }
-    cAddInt(v->address);
-    return v->type;
 }
 
-static bool cStoreRef(Variable* v, DataType should, DataType dt, Operation op) {
-    if(dtCompare(v->type, dtToReference(should)) && dtCompare(dt, should)) {
+static bool cStoreRef(Variable* v, Variable* sv, DataType should, DataType dt,
+                      Operation op) {
+    DataType type = v->type;
+    int offset = 0;
+    if(dtIsStructRef(type)) {
+        type = dtToReference(sv->type);
+        offset = sv->address;
+    }
+    if(dtCompare(type, dtToReference(should)) && dtCompare(dt, should)) {
         cAddOperation(OP_LOAD_INT);
         cAddInt(v->address);
+        cAddOffset(offset);
         cAddOperation(op);
         return true;
     }
     return false;
 }
 
-static void cStoreVariable(Variable* v, DataType dt, const char* name) {
-    if(cStoreRef(v, dtInt(), dt, OP_REF_STORE_INT) ||
-       cStoreRef(v, dtFloat(), dt, OP_REF_STORE_FLOAT) ||
-       cStoreRef(v, dtBool(), dt, OP_REF_STORE_BOOL)) {
-        return;
-    } else if(!dtCompare(v->type, dt)) {
+static void cStore(Variable* v, DataType dt, const char* name) {
+    if(!dtCompare(v->type, dt)) {
         cInvalidOperation(v->type, dt, name);
     } else if(v->type.reference) {
         cAddOperation(OP_STORE_INT);
@@ -269,6 +311,21 @@ static void cStoreVariable(Variable* v, DataType dt, const char* name) {
     cAddInt(v->address);
 }
 
+static void cStoreVariable(Variable* v, Variable* sv, DataType dt,
+                           const char* name) {
+    if(cStoreRef(v, sv, dtInt(), dt, OP_REF_STORE_INT) ||
+       cStoreRef(v, sv, dtFloat(), dt, OP_REF_STORE_FLOAT) ||
+       cStoreRef(v, sv, dtBool(), dt, OP_REF_STORE_BOOL)) {
+        return;
+    }
+    if(dtIsStruct(v->type) && sv->type.type != DT_VOID) {
+        sv->address += v->address;
+        cStore(sv, dt, name);
+    } else {
+        cStore(v, dt, name);
+    }
+}
+
 static DataType cPostChange(Variable* v, int change, const char* name) {
     if(!dtCompare(v->type, dtInt())) {
         cError("%s needs an int", name);
@@ -285,6 +342,24 @@ static DataType cPostChange(Variable* v, int change, const char* name) {
     return dtInt();
 }
 
+static void cWalkStruct(Variable* v, Variable* sv) {
+    sv->address = 0;
+    sv->name = "";
+    sv->type = dtVoid();
+    if(!cConsumeTokenIf(T_POINT)) {
+        return;
+    }
+    Struct* st = dtGetStruct(&structs, v->type);
+    if(st == NULL) {
+        cError("%s is not a struct but %s", v->name, cGetName(v->type));
+    }
+    cConsumeToken(T_LITERAL);
+    const char* name = cReadString();
+    if(vSearchStruct(sv, &structs, st, name)) {
+        cError("%s has no member %s", v->name, name);
+    }
+}
+
 static DataType cLiteral() {
     const char* literal = cReadString();
     if(cConsumeTokenIf(T_OPEN_BRACKET)) {
@@ -298,12 +373,14 @@ static DataType cLiteral() {
     if(v == NULL) {
         cNotDeclared(literal);
     }
+    Variable sv;
+    cWalkStruct(v, &sv);
     if(cConsumeTokenIf(T_INCREMENT)) {
         return cPostChange(v, 1, "++");
     } else if(cConsumeTokenIf(T_DECREMENT)) {
         return cPostChange(v, -1, "--");
     }
-    return cLoadVariable(v);
+    return cLoadVariable(v, &sv);
 }
 
 static DataType cBracketPrimary() {
@@ -400,8 +477,13 @@ static DataType cPreUnary() {
         if(v == NULL) {
             cNotDeclared(literal);
         }
-        cAddOperation(OP_VAR_REF);
-        cAddInt(v->address);
+        if(v->type.reference) {
+            cAddOperation(OP_LOAD_INT);
+            cAddInt(v->address);
+        } else {
+            cAddOperation(OP_VAR_REF);
+            cAddInt(v->address);
+        }
         return dtToReference(v->type);
     }
     return cPrimary();
@@ -568,11 +650,11 @@ static DataType cExpression() {
     return cOr();
 }
 
-static void cOperationSet(Variable* v, const TypedOp* op) {
-    DataType a = cLoadVariable(v);
+static void cOperationSet(Variable* v, Variable* sv, const TypedOp* op) {
+    DataType a = cLoadVariable(v, sv);
     DataType b = cExpression();
     cAddTypeOperation(a, b, op);
-    cStoreVariable(v, b, "=");
+    cStoreVariable(v, sv, b, "=");
 }
 
 static void cAddPostLineChange(Variable* v, int change, const char* name) {
@@ -596,16 +678,6 @@ static void cDeclareStruct(Struct* st) {
         cDeclared(var);
     }
     vAdd(&vars, var, dtStruct(st), &structs);
-    int varLength = strlen(var);
-    for(int i = 0; i < st->amount; i++) {
-        int length = strlen(st->vars[i].name);
-        char* fullName = malloc(varLength + length + 2);
-        memcpy(fullName, var, varLength);
-        fullName[varLength] = '.';
-        memcpy(fullName + varLength + 1, st->vars[i].name, length + 1);
-        vAdd(&vars, fullName, st->vars[i].type, &structs);
-        free(fullName);
-    }
 }
 
 static void cLineLiteral() {
@@ -626,19 +698,24 @@ static void cLineLiteral() {
     if(v == NULL) {
         cNotDeclared(literal);
     }
+    Variable sv;
+    sv.type = dtVoid();
+    cWalkStruct(v, &sv);
     Token t = cReadTokenAndLine();
     switch(t) {
-        case T_SET: cStoreVariable(v, cExpression(), "="); break;
-        case T_ADD_SET: cOperationSet(v, &TYPED_ADD); break;
-        case T_SUB_SET: cOperationSet(v, &TYPED_SUB); break;
-        case T_MUL_SET: cOperationSet(v, &TYPED_MUL); break;
-        case T_DIV_SET: cOperationSet(v, &TYPED_DIV); break;
-        case T_MOD_SET: cOperationSet(v, &TYPED_MOD); break;
-        case T_BIT_AND_SET: cOperationSet(v, &TYPED_BIT_AND); break;
-        case T_BIT_OR_SET: cOperationSet(v, &TYPED_BIT_OR); break;
-        case T_BIT_XOR_SET: cOperationSet(v, &TYPED_BIT_XOR); break;
-        case T_LEFT_SHIFT_SET: cOperationSet(v, &TYPED_LEFT_SHIFT); break;
-        case T_RIGHT_SHIFT_SET: cOperationSet(v, &TYPED_RIGHT_SHIFT); break;
+        case T_SET: cStoreVariable(v, &sv, cExpression(), "="); break;
+        case T_ADD_SET: cOperationSet(v, &sv, &TYPED_ADD); break;
+        case T_SUB_SET: cOperationSet(v, &sv, &TYPED_SUB); break;
+        case T_MUL_SET: cOperationSet(v, &sv, &TYPED_MUL); break;
+        case T_DIV_SET: cOperationSet(v, &sv, &TYPED_DIV); break;
+        case T_MOD_SET: cOperationSet(v, &sv, &TYPED_MOD); break;
+        case T_BIT_AND_SET: cOperationSet(v, &sv, &TYPED_BIT_AND); break;
+        case T_BIT_OR_SET: cOperationSet(v, &sv, &TYPED_BIT_OR); break;
+        case T_BIT_XOR_SET: cOperationSet(v, &sv, &TYPED_BIT_XOR); break;
+        case T_LEFT_SHIFT_SET: cOperationSet(v, &sv, &TYPED_LEFT_SHIFT); break;
+        case T_RIGHT_SHIFT_SET:
+            cOperationSet(v, &sv, &TYPED_RIGHT_SHIFT);
+            break;
         case T_INCREMENT: cAddPostLineChange(v, 1, "++"); break;
         case T_DECREMENT: cAddPostLineChange(v, -1, "--"); break;
         default: cUnexpectedToken(t);
@@ -787,7 +864,7 @@ static void cDeclare(DataType dt) {
     }
     v = vAdd(&vars, var, dt, &structs);
     cConsumeToken(T_SET);
-    cStoreVariable(v, cExpression(), "=");
+    cStore(v, cExpression(), "=");
 }
 
 static void cAddPreLineChange(int change, const char* name) {
@@ -899,9 +976,7 @@ static void cFunctionCommaOrEnd(Function* f) {
 }
 
 static void cFunctionAddArgument(Function* f, DataType dt) {
-    if(cConsumeTokenIf(T_BIT_AND)) {
-        dt = dtToReference(dt);
-    }
+    dt = cExtendType(dt);
     cConsumeToken(T_LITERAL);
     const char* name = cReadString();
     Variable* v = vSearchScope(&vars, name);
@@ -921,6 +996,16 @@ static void cFunctionArgument(Function* f) {
         case T_INT: cFunctionAddArgument(f, dtInt()); break;
         case T_FLOAT: cFunctionAddArgument(f, dtFloat()); break;
         case T_BOOL: cFunctionAddArgument(f, dtBool()); break;
+        case T_LITERAL:
+            {
+                const char* structName = cReadString();
+                Struct* st = stsSearch(&structs, structName);
+                if(st == NULL) {
+                    cError("struct %s does not exist");
+                }
+                cFunctionAddArgument(f, dtStruct(st));
+                break;
+            }
         default: cUnexpectedToken(t);
     }
 }

+ 31 - 14
DataType.c

@@ -23,15 +23,13 @@ static void dtAppend(const char* s) {
 const char* dtGetName(Structs* sts, DataType dt) {
     typeNameSwap = !typeNameSwap;
     typeNameIndex = 0;
-    if(dt.structId > 0) {
-        dtAppend(sts->data[dt.structId - 1].name);
-    } else {
-        switch(dt.type) {
-            case DT_INT: dtAppend("int"); break;
-            case DT_FLOAT: dtAppend("float"); break;
-            case DT_BOOL: dtAppend("bool"); break;
-            default: dtAppend("unknown");
-        }
+    switch(dt.type) {
+        case DT_INT: dtAppend("int"); break;
+        case DT_FLOAT: dtAppend("float"); break;
+        case DT_BOOL: dtAppend("bool"); break;
+        case DT_STRUCT: dtAppend(sts->data[dt.structId].name); break;
+        case DT_VOID: dtAppend("void"); break;
+        default: dtAppend("unknown");
     }
     for(unsigned int i = 0; i < dt.pointers; i++) {
         dtAppend("*");
@@ -43,16 +41,20 @@ const char* dtGetName(Structs* sts, DataType dt) {
 }
 
 int dtGetSize(DataType dt, Structs* sts) {
-    (void)sts;
     switch(dtAsInt(dt)) {
         case DT_INT: return sizeof(int);
         case DT_FLOAT: return sizeof(float);
         case DT_BOOL: return sizeof(bool);
-        default:
-            if(dt.structId > 0) {
-                return 0;
+        case DT_STRUCT:
+            {
+                int size = 0;
+                Struct* st = sts->data + dt.structId;
+                for(int i = 0; i < st->amount; i++) {
+                    size += dtGetSize(st->vars[i].type, sts);
+                }
+                return size;
             }
-            return sizeof(int);
+        default: return sizeof(int);
     }
 }
 
@@ -109,6 +111,21 @@ bool dtIsArray(DataType dt) {
     return dt.pointers > 0;
 }
 
+bool dtIsStruct(DataType dt) {
+    return dt.type == DT_STRUCT && dt.pointers == 0 && dt.reference == 0;
+}
+
+bool dtIsStructRef(DataType dt) {
+    return dt.type == DT_STRUCT && dt.pointers == 0 && dt.reference == 1;
+}
+
+Struct* dtGetStruct(Structs* sts, DataType dt) {
+    if(dt.type != DT_STRUCT || dt.pointers != 0) {
+        return NULL;
+    }
+    return sts->data + dt.structId;
+}
+
 void stAddVariable(Struct* st, const char* name, DataType type) {
     int index = st->amount;
     st->amount++;

+ 3 - 0
DataType.h

@@ -45,6 +45,9 @@ DataType dtStruct(Struct* st);
 DataType dtToReference(DataType dt);
 DataType dtToArray(DataType dt, int dimension);
 bool dtIsArray(DataType dt);
+bool dtIsStruct(DataType dt);
+bool dtIsStructRef(DataType dt);
+Struct* dtGetStruct(Structs* sts, DataType dt);
 
 bool dtCompare(DataType a, DataType b);
 int dtMaxDimensions();

+ 1 - 0
tests/struct/bool_reference

@@ -9,6 +9,7 @@ void main() {
     print a;
     print b;
     test(&a);
+    test(&b);
     print a;
     print b;
 }

+ 1 - 0
tests/struct/float_reference

@@ -10,6 +10,7 @@ void main() {
     print a;
     print b;
     test(&a);
+    test(&b);
     print a;
     print b;
 }

+ 1 - 0
tests/struct/int_reference

@@ -10,6 +10,7 @@ void main() {
     print a;
     print b;
     test(&a);
+    test(&b);
     print a;
     print b;
 }

+ 38 - 0
tests/struct/pass_struct

@@ -0,0 +1,38 @@
+struct A {
+    int i;
+    bool b;
+};
+
+void test(A a) {
+    print a.i;
+    print a.b;
+    
+    a.i = 2;
+    a.b = false;
+    
+    print a.i;
+    print a.b;
+}
+
+void test(A& a) {
+    print a.i;
+    print a.b;
+    
+    a.i = 2;
+    a.b = false;
+    
+    print a.i;
+    print a.b;
+}
+
+void main() {
+    A a;
+    a.i = 3;
+    a.b = true;
+    test(a);
+    print a.i;
+    print a.b;
+    test(&a);
+    print a.i;
+    print a.b;
+}

+ 12 - 0
tests/struct/pass_struct.out

@@ -0,0 +1,12 @@
+3
+true
+2
+false
+3
+true
+3
+true
+2
+false
+2
+false

+ 1 - 1
tokenizer/Tokenizer.c

@@ -52,7 +52,7 @@ static bool tParseLiteral(int c) {
     int index = 1;
     char buffer[64];
     buffer[0] = c;
-    while(isLetter(fPeek()) || fPeek() == '.') {
+    while(isLetter(fPeek())) {
         if(index >= 63) {
             tError("literal is too long");
             return false;

+ 2 - 2
utils/ByteCodePrinter.c

@@ -110,8 +110,8 @@ static void btPrintFloat(const char* op) {
 #define PRINT_TYPES(TYPE)                                                      \
     PRINT_OP_INT(OP_LOAD_##TYPE);                                              \
     PRINT_OP_INT(OP_STORE_##TYPE);                                             \
-    PRINT_OP_INT(OP_REF_LOAD_##TYPE);                                          \
-    PRINT_OP_INT(OP_REF_STORE_##TYPE);                                         \
+    PRINT_OP(OP_REF_LOAD_##TYPE);                                              \
+    PRINT_OP(OP_REF_STORE_##TYPE);                                             \
     PRINT_OP_INT(OP_RETURN_##TYPE);                                            \
     PRINT_OP(OP_EQUAL_##TYPE);
 

+ 13 - 4
utils/Variables.c

@@ -10,9 +10,6 @@ void vInit(Variables* v) {
 }
 
 void vDelete(Variables* v) {
-    for(int i = 0; i < v->entries; i++) {
-        free(v->data[i].name);
-    }
     free(v->data);
 }
 
@@ -33,13 +30,25 @@ Variable* vSearchScope(Variables* v, const char* s) {
     return vSearchUntil(v, s, v->scope);
 }
 
+bool vSearchStruct(Variable* v, Structs* sts, Struct* st, const char* s) {
+    for(int i = 0; i < st->amount; i++) {
+        if(strcmp(st->vars[i].name, s) == 0) {
+            v->name = s;
+            v->type = st->vars[i].type;
+            return false;
+        }
+        v->address += dtGetSize(st->vars[i].type, sts);
+    }
+    return true;
+}
+
 Variable* vAdd(Variables* v, const char* s, DataType type, Structs* sts) {
     if(v->entries >= v->capacity) {
         v->capacity *= 2;
         v->data = realloc(v->data, sizeof(Variable) * v->capacity);
     }
     int index = v->entries++;
-    v->data[index] = (Variable){strdup(s), type, v->address};
+    v->data[index] = (Variable){s, type, v->address};
     v->address += dtGetSize(type, sts);
     if(v->address > v->maxAddress) {
         v->maxAddress = v->address;

+ 2 - 1
utils/Variables.h

@@ -12,7 +12,7 @@ typedef struct {
 } Scope;
 
 typedef struct {
-    char* name;
+    const char* name;
     DataType type;
     int address;
 } Variable;
@@ -29,6 +29,7 @@ typedef struct {
 void vInit(Variables* v);
 void vDelete(Variables* v);
 Variable* vSearch(Variables* v, const char* s);
+bool vSearchStruct(Variable* v, Structs* sts, Struct* st, const char* s);
 Variable* vSearchScope(Variables* v, const char* s);
 Variable* vAdd(Variables* v, const char* s, DataType type, Structs* sts);
 void vReset(Variables* v);

+ 2 - 3
vm/Script.c

@@ -248,9 +248,8 @@ static void sReturn(Script* sc) {
     if(sReadInt(sc, &bytes) && sPopInt(sc, &varIndex)) {
         sc->stackVarIndex = varIndex;
         sFree(sc, bytes);
-        int returnIndex;
-        if(sPopInt(sc, &returnIndex)) {
-            sc->readIndex = returnIndex;
+        if(!sPopInt(sc, &sc->readIndex) || sc->readIndex < 0) {
+            sError(sc, "read index is corrupt");
         }
     }
 }