Browse Source

functions can return pointers

Kajetan Johannes Hammerle 3 years ago
parent
commit
45806abc7b
6 changed files with 72 additions and 20 deletions
  1. 15 20
      Compiler.c
  2. 38 0
      tests/types/types
  3. 8 0
      tests/types/types.out
  4. 1 0
      utils/ByteCodePrinter.c
  5. 1 0
      vm/Operation.h
  6. 9 0
      vm/Script.c

+ 15 - 20
Compiler.c

@@ -198,6 +198,15 @@ static const char* cReadString() {
     return literal;
 }
 
+static DataType cReadStruct() {
+    const char* name = cReadString();
+    Struct* st = stsSearch(&structs, name);
+    if(st == NULL) {
+        cError("struct %s does not exist");
+    }
+    return dtStruct(st);
+}
+
 static DataType cExpression();
 
 static void cLoadRef(DataType type) {
@@ -311,16 +320,7 @@ static DataType cReadType() {
         case T_INT: dt = dtInt(); break;
         case T_BOOL: dt = dtBool(); break;
         case T_FLOAT: dt = dtFloat(); break;
-        case T_LITERAL:
-            {
-                const char* name = cReadString();
-                Struct* st = stsSearch(&structs, name);
-                if(st == NULL) {
-                    cError("struct %s does not exist");
-                }
-                dt = dtStruct(st);
-                break;
-            }
+        case T_LITERAL: dt = cReadStruct(); break;
         default: cUnexpectedToken(t);
     }
     return cExtendType(dt);
@@ -741,6 +741,8 @@ static void cReturn() {
         cAddReturn(OP_RETURN_BOOL);
     } else if(dtCompare(dt, dtFloat())) {
         cAddReturn(OP_RETURN_FLOAT);
+    } else if(dtIsPointer(dt)) {
+        cAddReturn(OP_RETURN_POINTER);
     } else {
         cError("cannot return %s", cGetName(dt));
     }
@@ -996,16 +998,7 @@ static void cFunctionArgument(Function* f) {
         case T_INT: cFunctionAddArgument(f, dtInt()); break;
         case T_FLOAT: cFunctionAddArgument(f, dtFloat()); break;
         case T_BOOL: cFunctionAddArgument(f, dtBool()); break;
-        case T_LITERAL:
-            {
-                const char* structName = cReadString();
-                Struct* st = stsSearch(&structs, structName);
-                if(st == NULL) {
-                    cError("struct %s does not exist");
-                }
-                cFunctionAddArgument(f, dtStruct(st));
-                break;
-            }
+        case T_LITERAL: cFunctionAddArgument(f, cReadStruct()); break;
         default: cUnexpectedToken(t);
     }
 }
@@ -1068,6 +1061,7 @@ static void cBuildFunction(Function* f, DataType rType) {
 }
 
 static void cFunction(DataType rType) {
+    rType = cExtendType(rType);
     Function f;
     cBuildFunction(&f, rType);
     Function* found = fsSearch(&functions, &f);
@@ -1121,6 +1115,7 @@ static void cGlobalScope(Token t) {
         case T_INT: cFunction(dtInt()); break;
         case T_BOOL: cFunction(dtBool()); break;
         case T_FLOAT: cFunction(dtFloat()); break;
+        case T_LITERAL: cFunction(cReadStruct()); break;
         case T_STRUCT: cStruct(); break;
         default: cUnexpectedToken(t);
     }

+ 38 - 0
tests/types/types

@@ -1,3 +1,7 @@
+struct A {
+    int a;
+};
+
 int intFunction() {
     return 1;
 }
@@ -10,6 +14,22 @@ float floatFunction() {
     return 2.0;
 }
 
+int* intFunction(int* i) {
+    return i;
+}
+
+bool* boolFunction(bool* b) {
+    return b;
+}
+
+float* floatFunction(float* f) {
+    return f;
+}
+
+A* structFunction(A* a) {
+    return a;
+}
+
 void main() {
     int i = intFunction();
     print i;
@@ -17,4 +37,22 @@ void main() {
     print b;
     float f = floatFunction();
     print f;
+    
+    print *intFunction(&i);
+    print *boolFunction(&b);
+    print *floatFunction(&f);
+    
+    print intFunction(&i)[0];
+    print boolFunction(&b)[0];
+    print floatFunction(&f)[0];
+    
+    A a;
+    a.a = 53453;
+    print structFunction(&a)->a;
+    structFunction(&a)->a = 123443;
+    structFunction(&a)[0].a += 3;
+    (*structFunction(&a)).a += 1;
+    (&a)->a += 4;
+    (*(&a)).a += 5;
+    print structFunction(&a)->a;
 }

+ 8 - 0
tests/types/types.out

@@ -1,3 +1,11 @@
 1
 true
 2.00
+1
+true
+2.00
+1
+true
+2.00
+53453
+123456

+ 1 - 0
utils/ByteCodePrinter.c

@@ -165,6 +165,7 @@ static void btConsumeOperation() {
         PRINT_OP_INT(OP_PEEK_TRUE_GOTO);
         PRINT_OP_INT2(OP_GOSUB);
         PRINT_OP_INT(OP_RETURN);
+        PRINT_OP_INT(OP_RETURN_POINTER);
         PRINT_OP_INT2(OP_RESERVE);
         PRINT_OP_INT(OP_DEREFERENCE_VAR);
         PRINT_OP(OP_REFERENCE);

+ 1 - 0
vm/Operation.h

@@ -40,6 +40,7 @@ typedef enum Operation {
     OP_GOSUB,
     OP_RETURN,
     TYPE_OPERATION(RETURN),
+    OP_RETURN_POINTER,
     OP_RESERVE,
     OP_LOAD,
     TYPE_OPERATION(LOAD),

+ 9 - 0
vm/Script.c

@@ -263,6 +263,14 @@ static void sReturn(Script* sc) {
     }
 }
 
+static void sReturnPointer(Script* sc) {
+    Pointer p;
+    if(sPopPointer(sc, &p)) {
+        sReturn(sc);
+        sPushPointer(sc, &p);
+    }
+}
+
 #define RETURN(type, Type)                                                     \
     {                                                                          \
         type value;                                                            \
@@ -509,6 +517,7 @@ static void sConsumeInstruction(Script* sc) {
         case OP_PEEK_TRUE_GOTO: sPeekTrueGoTo(sc); break;
         case OP_GOSUB: sGoSub(sc); break;
         case OP_RETURN: sReturn(sc); break;
+        case OP_RETURN_POINTER: sReturnPointer(sc); break;
         case OP_RESERVE: sReserveBytes(sc); break;
         case OP_DEREFERENCE_VAR: sDereference(sc); break;
         case OP_REFERENCE: sLoad(sc, sizeof(Pointer)); break;