// Я підтримую лише додавання та множення (для мого FFGEMM).
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>
/* Operations a graph node can be produced by; used to index ForwardOps/BackwardOps. */
enum OperationEnum {
    ADD = 0,          /* binary addition */
    MULTIPLY = 1,     /* binary multiplication */
    OPERATION_LENGTH  /* count of operations; must stay last */
};
/* Handle through which every graph node is created, linked and released. */
typedef struct value_struct *Value;

/*
 * One node of the computation graph.
 * Leaves come from CreateNewValue(); interior nodes come from AddValues() /
 * MultiplyValues(), which fill in children, operationIndex and backward.
 */
struct value_struct
{
    mpz_t number;                      /* forward (computed) value */
    mpz_t gradient;                    /* gradient accumulated by backward() */
    struct value_struct *children[2];  /* operands; NULL for leaf nodes */
    int numberOfChildren;              /* how many entries of children[] are used */
    int operationIndex;                /* OperationEnum tag of the producing op */
    void (*backward)(Value);           /* pushes gradient to children; NULL for leaves */
};
/*
 * Allocates a new leaf node holding a copy of `value`, with a zero gradient,
 * no children and no backward rule.
 *
 * Returns the new node, or NULL if the node itself cannot be allocated.
 * The caller owns the node and releases it with DestroyValue().
 */
Value CreateNewValue(mpz_t value)
{
    Value newValue = malloc(sizeof *newValue);
    if (newValue == NULL) {
        return NULL;
    }
    mpz_init_set(newValue->number, value);
    mpz_init(newValue->gradient); /* mpz_init already sets the integer to 0 */
    newValue->children[0] = NULL;
    newValue->children[1] = NULL;
    newValue->numberOfChildren = 0;
    newValue->operationIndex = 0;
    newValue->backward = NULL;
    return newValue;
}
/*
 * Prints a node's value and gradient on one line, followed by a newline.
 * A NULL node prints nothing.
 */
void PrintValue(Value value)
{
    if (value == NULL) {
        return;
    }
    /* %Zd is GMP's conversion for mpz_t. The former "%n" made gmp_printf
       store the character count through the mpz_t argument (memory
       corruption) and never printed either number. */
    gmp_printf("Значення(value = %Zd, gradient = %Zd)\n", value->number, value->gradient);
}
/*
 * Releases a node's GMP integers and the node itself.
 * Children are not touched; each node must be destroyed separately.
 * Passing NULL is a no-op, matching PrintValue's NULL tolerance.
 */
void DestroyValue(Value value)
{
    if (value == NULL) {
        return;
    }
    mpz_clear(value->number);
    mpz_clear(value->gradient);
    free(value);
}
/*
 * Backward rule for z = x + y.
 * Both partial derivatives are 1, so each operand's gradient is increased
 * by this node's gradient (operand 0 first, then operand 1).
 */
void AddBack(Value v)
{
    for (int side = 0; side < 2; ++side) {
        Value operand = v->children[side];
        mpz_add(operand->gradient, operand->gradient, v->gradient);
    }
}
/*
 * Backward rule for z = x * y.
 * dz/dx = y and dz/dy = x, so each operand's gradient is increased by the
 * other operand's value times this node's gradient.
 */
void MultiplyBack(Value v)
{
    Value left = v->children[0];
    Value right = v->children[1];
    /* mpz_addmul(r, a, b) computes r += a * b in place. */
    mpz_addmul(left->gradient, right->number, v->gradient);
    mpz_addmul(right->gradient, left->number, v->gradient);
}
Value AddValues(Value a, Value b)
{
mpz_t temporary; mpz_init(temporary);
mpz_add(temporary, a->number, b->number);
Value value = CreateNewValue(temporary);
value->children[0] = a;
value->children[1] = b;
value->numberOfChildren = 2;
value->operationIndex = ADD;
value->backward = AddBack;
mpz_clear(temporary);
return value;
}
Value MultiplyValues(Value a, Value b)
{
mpz_t temporary; mpz_init(temporary);
mpz_mul(temporary, a->number, b->number);
Value value = CreateNewValue(temporary);
value->children[0] = a;
value->children[1] = b;
value->numberOfChildren = 2;
value->operationIndex = MULTIPLY;
value->backward = MultiplyBack;
mpz_clear(temporary);
return value;
}
/*
 * Depth-first post-order traversal starting at `v`.
 *
 * Appends each node reachable from `v` to `topo` after all of its children,
 * so iterating `topo` backwards visits consumers before their operands.
 * `visited` records nodes already entered so shared sub-expressions are
 * emitted only once. Both counters are advanced in place.
 *
 * The caller must size `topo` and `visited` to hold every reachable node;
 * no bounds checking is done here. Duplicate detection is a linear scan of
 * `visited`, so the traversal is quadratic in the node count, and recursion
 * depth equals the graph depth.
 */
void build_topo(Value v, Value* topo, int* topo_size, Value* visited, int* visited_size) {
    if (v == NULL) {
        return;
    }
    for (int i = 0; i < *visited_size; ++i) {
        if (visited[i] == v) {
            return;
        }
    }
    visited[(*visited_size)++] = v;
    for (int i = 0; i < v->numberOfChildren; ++i) {
        build_topo(v->children[i], topo, topo_size, visited, visited_size);
    }
    topo[(*topo_size)++] = v;
}
/* Growable array of node handles used by backward() for its traversal. */
struct BackwardNodeList
{
    Value *items;
    size_t size;
    size_t capacity;
};

/* Appends `v`, doubling the buffer when full. Returns 0 on success, -1 on OOM. */
static int BackwardNodeListPush(struct BackwardNodeList *list, Value v)
{
    if (list->size == list->capacity) {
        size_t newCapacity = list->capacity ? list->capacity * 2 : 64;
        Value *grown = realloc(list->items, newCapacity * sizeof *grown);
        if (grown == NULL) {
            return -1; /* list->items is still valid and freed by the caller */
        }
        list->items = grown;
        list->capacity = newCapacity;
    }
    list->items[list->size++] = v;
    return 0;
}

/* Returns 1 if `v` is already stored in `list`, 0 otherwise (linear scan). */
static int BackwardNodeListContains(const struct BackwardNodeList *list, Value v)
{
    for (size_t i = 0; i < list->size; ++i) {
        if (list->items[i] == v) {
            return 1;
        }
    }
    return 0;
}

/*
 * Post-order DFS like build_topo(), but collecting into growable lists so the
 * graph size is not capped. Returns 0 on success, -1 if memory runs out.
 */
static int BackwardCollect(Value v, struct BackwardNodeList *order,
                           struct BackwardNodeList *visited)
{
    if (v == NULL || BackwardNodeListContains(visited, v)) {
        return 0;
    }
    if (BackwardNodeListPush(visited, v) != 0) {
        return -1;
    }
    for (int i = 0; i < v->numberOfChildren; ++i) {
        if (BackwardCollect(v->children[i], order, visited) != 0) {
            return -1;
        }
    }
    return BackwardNodeListPush(order, v);
}

/*
 * Reverse-mode differentiation from `root`.
 *
 * Sets root's gradient to 1. It then runs every reachable node's backward
 * rule in reverse topological order, so a node has received all contributions
 * from its consumers before propagating to its own operands. Gradients of
 * non-root nodes are not reset first; the rules add into them, so repeated
 * calls accumulate.
 *
 * Previously the traversal used fixed 1000-entry stack arrays and overflowed
 * for larger graphs; the lists are now heap-allocated and grow as needed.
 * On allocation failure a message is written to stderr and no gradient is
 * modified. A NULL root is a no-op.
 */
void backward(Value root)
{
    if (root == NULL) {
        return;
    }
    struct BackwardNodeList order = {0};
    struct BackwardNodeList visited = {0};
    if (BackwardCollect(root, &order, &visited) != 0) {
        fprintf(stderr, "backward: out of memory while ordering the graph\n");
    } else {
        mpz_set_si(root->gradient, 1);
        for (size_t i = order.size; i-- > 0;) {
            Value node = order.items[i];
            if (node->backward != NULL) {
                node->backward(node);
            }
        }
    }
    free(order.items);
    free(visited.items);
}
typedef Value (*f)(Value, Value);
f ForwardOps[] = {&AddValues, &MultiplyValues};
typedef void (*g)(Value);
g BackwardOps[] = {&AddBack, &MultiplyBack};
/*
 * Smoke test: builds L = (a * b + c) * factor with a = 5, b = -50, c = 15 and
 * factor = -5. It back-propagates from L, prints every node, then frees them.
 * Hand-computed (value, gradient) pairs:
 *   a (5, 250), b (-50, -25), c (15, -5), d (-235, -5),
 *   e (-250, -5), factor (-5, -235), L (1175, 1).
 */
void TestGrad()
{
    mpz_t scratch;
    mpz_init_set_ui(scratch, 5);
    Value a = CreateNewValue(scratch);

    mpz_set_si(scratch, -50);
    Value b = CreateNewValue(scratch);

    mpz_set_ui(scratch, 15);
    Value c = CreateNewValue(scratch);

    Value e = ForwardOps[MULTIPLY](a, b);      /* e = a * b */
    Value d = ForwardOps[ADD](e, c);           /* d = e + c */

    mpz_set_si(scratch, -5);
    Value factor = CreateNewValue(scratch);
    Value L = ForwardOps[MULTIPLY](d, factor); /* L = d * factor */

    backward(L);

    Value nodes[] = {a, b, c, d, e, factor, L};
    int count = (int)(sizeof nodes / sizeof nodes[0]);
    for (int k = 0; k < count; ++k) {
        PrintValue(nodes[k]);
    }
    for (int k = 0; k < count; ++k) {
        DestroyValue(nodes[k]);
    }
    mpz_clear(scratch);
}
// Перекладено з: [Autograd in C with LibGMP](https://kibichomurage.medium.com/autograd-in-c-with-libgmp-1603639e95ef?source=rss------c-5)