Skip to content

Commit 1419bab

Browse files
author
psainics
committed
Add Oracle UPSERT
1 parent ec0383c commit 1419bab

7 files changed

Lines changed: 340 additions & 26 deletions

File tree

database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,7 @@ protected void upsertOperation(PreparedStatement stmt) throws SQLException {
314314
}
315315

316316
private boolean fillUpdateParams(List<String> updatedKeyList, ColumnType columnType) {
317-
if (operationName.equals(Operation.UPDATE) && updatedKeyList.contains(columnType.getName())) {
318-
return true;
319-
}
320-
return false;
317+
return operationName.equals(Operation.UPDATE) && updatedKeyList.contains(columnType.getName());
321318
}
322319

323320
private Schema getNonNullableSchema(Schema.Field field) {
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright © 2026 Cask Data, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
* use this file except in compliance with the License. You may obtain a copy of
6+
* the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations under
14+
* the License.
15+
*/
16+
17+
package io.cdap.plugin.oracle;
18+
19+
import io.cdap.plugin.db.sink.ETLDBOutputFormat;
20+
21+
/**
22+
* Class that extends {@link ETLDBOutputFormat} to implement the abstract methods
23+
*/
24+
public class OracleETLDBOutputFormat extends ETLDBOutputFormat {
25+
26+
/**
27+
* This method is used to construct the upsert query for Oracle using MERGE statement.
28+
* Example - MERGE INTO my_table target
29+
* USING (SELECT ? AS id, ? AS name, ? AS age FROM dual) source
30+
* ON (target.id = source.id)
31+
* WHEN MATCHED THEN UPDATE SET target.name = source.name, target.age = source.age
32+
* WHEN NOT MATCHED THEN INSERT (id, name, age) VALUES (source.id, source.name, source.age)
33+
* @param table - Name of the table
34+
* @param fieldNames - All the columns of the table
35+
* @param listKeys - The columns used as keys for matching
36+
* @return Upsert query in the form of string
37+
*/
38+
@Override
39+
public String constructUpsertQuery(String table, String[] fieldNames, String[] listKeys) {
40+
if (listKeys == null) {
41+
throw new IllegalArgumentException("Column names to be updated should not be null");
42+
} else if (fieldNames == null) {
43+
throw new IllegalArgumentException("Field names should not be null");
44+
} else {
45+
StringBuilder query = new StringBuilder();
46+
47+
// MERGE INTO target_table target
48+
query.append("MERGE INTO ").append(table).append(" target ");
49+
50+
// USING (SELECT ? AS col1, ? AS col2, ... FROM dual) source
51+
query.append("USING (SELECT ");
52+
for (int i = 0; i < fieldNames.length; ++i) {
53+
query.append("? AS ").append(fieldNames[i]);
54+
if (i != fieldNames.length - 1) {
55+
query.append(", ");
56+
}
57+
}
58+
query.append(" FROM dual) source ");
59+
60+
// ON (target.key1 = source.key1 AND target.key2 = source.key2 ...)
61+
query.append("ON (");
62+
for (int i = 0; i < listKeys.length; ++i) {
63+
query.append("target.").append(listKeys[i]).append(" = source.").append(listKeys[i]);
64+
if (i != listKeys.length - 1) {
65+
query.append(" AND ");
66+
}
67+
}
68+
query.append(") ");
69+
70+
// WHEN MATCHED THEN UPDATE SET target.col1 = source.col1, target.col2 = source.col2 ...
71+
// Only update non-key columns
72+
query.append("WHEN MATCHED THEN UPDATE SET ");
73+
boolean firstUpdateColumn = true;
74+
for (String fieldName : fieldNames) {
75+
boolean isKeyColumn = false;
76+
for (String listKey : listKeys) {
77+
String listKeyNoQuote = listKey.replace("\"", "");
78+
if (listKeyNoQuote.equals(fieldName)) {
79+
isKeyColumn = true;
80+
break;
81+
}
82+
}
83+
if (!isKeyColumn) {
84+
if (!firstUpdateColumn) {
85+
query.append(", ");
86+
}
87+
query.append("target.").append(fieldName).append(" = source.").append(fieldName);
88+
firstUpdateColumn = false;
89+
}
90+
}
91+
92+
// WHEN NOT MATCHED THEN INSERT (col1, col2, ...) VALUES (source.col1, source.col2, ...)
93+
query.append(" WHEN NOT MATCHED THEN INSERT (");
94+
for (int i = 0; i < fieldNames.length; ++i) {
95+
query.append(fieldNames[i]);
96+
if (i != fieldNames.length - 1) {
97+
query.append(", ");
98+
}
99+
}
100+
query.append(") VALUES (");
101+
for (int i = 0; i < fieldNames.length; ++i) {
102+
query.append("source.").append(fieldNames[i]);
103+
if (i != fieldNames.length - 1) {
104+
query.append(", ");
105+
}
106+
}
107+
query.append(")");
108+
109+
return query.toString();
110+
}
111+
}
112+
113+
@Override
114+
public String constructUpdateQuery(String table, String[] fieldNames, String[] listKeys) {
115+
// Oracle JDBC does not accept a trailing semicolon in prepared statements.
116+
String query = super.constructUpdateQuery(table, fieldNames, listKeys);
117+
if (query.endsWith(";")) {
118+
return query.substring(0, query.length() - 1);
119+
}
120+
return query;
121+
}
122+
}

oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSink.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.cdap.cdap.api.annotation.MetadataProperty;
2424
import io.cdap.cdap.api.annotation.Name;
2525
import io.cdap.cdap.api.annotation.Plugin;
26+
import io.cdap.cdap.api.data.batch.Output;
2627
import io.cdap.cdap.api.data.format.StructuredRecord;
2728
import io.cdap.cdap.etl.api.FailureCollector;
2829
import io.cdap.cdap.etl.api.batch.BatchSink;
@@ -31,6 +32,7 @@
3132
import io.cdap.plugin.common.Asset;
3233
import io.cdap.plugin.common.ConfigUtil;
3334
import io.cdap.plugin.common.LineageRecorder;
35+
import io.cdap.plugin.common.batch.sink.SinkOutputFormatProvider;
3436
import io.cdap.plugin.common.db.DBErrorDetailsProvider;
3537
import io.cdap.plugin.db.DBRecord;
3638
import io.cdap.plugin.db.SchemaReader;
@@ -60,7 +62,8 @@ public OracleSink(OracleSinkConfig oracleSinkConfig) {
6062

6163
@Override
6264
protected DBRecord getDBRecord(StructuredRecord output) {
63-
return new OracleSinkDBRecord(output, columnTypes);
65+
return new OracleSinkDBRecord(output, columnTypes, oracleSinkConfig.getOperationName(),
66+
oracleSinkConfig.getRelationTableKey());
6467
}
6568

6669
@Override
@@ -72,6 +75,13 @@ protected FieldsValidator getFieldsValidator() {
7275
protected SchemaReader getSchemaReader() {
7376
return new OracleSinkSchemaReader();
7477
}
78+
79+
@Override
80+
protected void addOutputContext(BatchSinkContext context) {
81+
context.addOutput(Output.of(oracleSinkConfig.getReferenceName(),
82+
new SinkOutputFormatProvider(OracleETLDBOutputFormat.class, getConfiguration())));
83+
}
84+
7585
@Override
7686
protected LineageRecorder getLineageRecorder(BatchSinkContext context) {
7787
String fqn = DBUtils.constructFQN("oracle",

oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSinkDBRecord.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.cdap.cdap.api.data.format.StructuredRecord;
2020
import io.cdap.cdap.api.data.schema.Schema;
2121
import io.cdap.plugin.db.ColumnType;
22+
import io.cdap.plugin.db.Operation;
2223
import io.cdap.plugin.db.SchemaReader;
2324

2425
import java.sql.PreparedStatement;
@@ -31,9 +32,12 @@
3132
*/
3233
public class OracleSinkDBRecord extends OracleSourceDBRecord {
3334

34-
public OracleSinkDBRecord(StructuredRecord record, List<ColumnType> columnTypes) {
35+
public OracleSinkDBRecord(StructuredRecord record, List<ColumnType> columnTypes, Operation operationName,
36+
String relationTableKey) {
3537
this.record = record;
3638
this.columnTypes = columnTypes;
39+
this.operationName = operationName;
40+
this.relationTableKey = relationTableKey;
3741
}
3842

3943
@Override
@@ -50,4 +54,13 @@ protected void insertOperation(PreparedStatement stmt) throws SQLException {
5054
writeToDB(stmt, field, fieldIndex);
5155
}
5256
}
57+
58+
@Override
59+
protected void upsertOperation(PreparedStatement stmt) throws SQLException {
60+
for (int fieldIndex = 0; fieldIndex < columnTypes.size(); fieldIndex++) {
61+
ColumnType columnType = columnTypes.get(fieldIndex);
62+
Schema.Field field = record.getSchema().getField(columnType.getName());
63+
writeToDB(stmt, field, fieldIndex);
64+
}
65+
}
5366
}

oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,17 @@
1919
import com.google.common.io.ByteStreams;
2020
import io.cdap.cdap.api.common.Bytes;
2121
import io.cdap.cdap.api.data.format.StructuredRecord;
22+
import io.cdap.cdap.api.data.format.StructuredRecord.Builder;
2223
import io.cdap.cdap.api.data.schema.Schema;
24+
import io.cdap.cdap.api.data.schema.Schema.Field;
25+
import io.cdap.cdap.api.data.schema.Schema.LogicalType;
26+
import io.cdap.cdap.api.data.schema.Schema.Type;
2327
import io.cdap.cdap.etl.api.validation.InvalidStageException;
2428
import io.cdap.plugin.db.ColumnType;
2529
import io.cdap.plugin.db.DBRecord;
2630
import io.cdap.plugin.db.SchemaReader;
31+
import org.apache.hadoop.io.Writable;
32+
import org.apache.hadoop.mapreduce.lib.db.DBWritable;
2733

2834
import java.io.IOException;
2935
import java.io.InputStream;
@@ -45,8 +51,8 @@
4551
import java.util.List;
4652

4753
/**
48-
* Oracle Source implementation {@link org.apache.hadoop.mapreduce.lib.db.DBWritable} and
49-
* {@link org.apache.hadoop.io.Writable}.
54+
* Oracle Source implementation {@link DBWritable} and
55+
* {@link Writable}.
5056
*/
5157
public class OracleSourceDBRecord extends DBRecord {
5258

@@ -77,11 +83,11 @@ protected SchemaReader getSchemaReader() {
7783
public void readFields(ResultSet resultSet) throws SQLException {
7884
Schema schema = getSchema();
7985
ResultSetMetaData metadata = resultSet.getMetaData();
80-
StructuredRecord.Builder recordBuilder = StructuredRecord.builder(schema);
86+
Builder recordBuilder = StructuredRecord.builder(schema);
8187

8288
// All LONG or LONG RAW columns have to be retrieved from the ResultSet prior to all the other columns.
8389
// Otherwise, we will face java.sql.SQLException: Stream has already been closed
84-
for (Schema.Field field : schema.getFields()) {
90+
for (Field field : schema.getFields()) {
8591
// Index of a field in the schema may not be same in the ResultSet,
8692
// hence find the field by name in the given resultSet
8793
int columnIndex = resultSet.findColumn(field.getName());
@@ -91,7 +97,7 @@ public void readFields(ResultSet resultSet) throws SQLException {
9197
}
9298

9399
// Read fields of other types
94-
for (Schema.Field field : schema.getFields()) {
100+
for (Field field : schema.getFields()) {
95101
// Index of a field in the schema may not be same in the ResultSet,
96102
// hence find the field by name in the given resultSet
97103
int columnIndex = resultSet.findColumn(field.getName());
@@ -104,7 +110,7 @@ record = recordBuilder.build();
104110
}
105111

106112
@Override
107-
protected void handleField(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field,
113+
protected void handleField(ResultSet resultSet, Builder recordBuilder, Field field,
108114
int columnIndex, int sqlType, int sqlPrecision, int sqlScale) throws SQLException {
109115
if (OracleSourceSchemaReader.ORACLE_TYPES.contains(sqlType) || sqlType == Types.NCLOB) {
110116
handleOracleSpecificType(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale);
@@ -116,8 +122,21 @@ protected void handleField(ResultSet resultSet, StructuredRecord.Builder recordB
116122
@Override
117123
protected void writeNonNullToDB(PreparedStatement stmt, Schema fieldSchema,
118124
String fieldName, int fieldIndex) throws SQLException {
119-
int sqlType = columnTypes.get(fieldIndex).getType();
120125
int sqlIndex = fieldIndex + 1;
126+
int sqlType = Types.OTHER;
127+
boolean isFieldTypeFound = false;
128+
// avoid OOB exception in case of mismatch between columnTypes and record schema fields
129+
for (ColumnType columnType : columnTypes) {
130+
if (columnType.getName().equals(fieldName)) {
131+
sqlType = columnType.getType();
132+
isFieldTypeFound = true;
133+
break;
134+
}
135+
}
136+
137+
if (!isFieldTypeFound) {
138+
throw new IllegalArgumentException("Unable to find the column type for field '" + fieldName + "'.");
139+
}
121140

122141
// TIMESTAMP and TIMESTAMPTZ types needs to be handled using the specific oracle types to ensure that the data
123142
// inserted matches with the provided value. As Oracle driver internally alters the values provided
@@ -126,7 +145,7 @@ protected void writeNonNullToDB(PreparedStatement stmt, Schema fieldSchema,
126145
// More details here : https://docs.oracle.com/cd/E13222_01/wls/docs91/jdbc_drivers/oracle.html
127146
// Handle the case when TimestampTZ type is set to CDAP String type or Timestamp type
128147
if (sqlType == OracleSourceSchemaReader.TIMESTAMP_TZ) {
129-
if (Schema.Type.STRING.equals(fieldSchema.getType())) {
148+
if (Type.STRING.equals(fieldSchema.getType())) {
130149
// Deprecated: Handle the case when the TimestampTZ is mapped to CDAP String type
131150
String timestampString = record.get(fieldName);
132151
Object timestampTZ = createOracleTimestampWithTimeZone(stmt.getConnection(), timestampString);
@@ -140,21 +159,21 @@ protected void writeNonNullToDB(PreparedStatement stmt, Schema fieldSchema,
140159
stmt.setObject(sqlIndex, timestampWithTimeZone);
141160
}
142161
} else if (sqlType == OracleSourceSchemaReader.TIMESTAMP_LTZ) {
143-
if (Schema.LogicalType.TIMESTAMP_MICROS.equals(fieldSchema.getLogicalType())) {
162+
if (LogicalType.TIMESTAMP_MICROS.equals(fieldSchema.getLogicalType())) {
144163
// Deprecated: Handle the case when the TimestampLTZ is mapped to CDAP Timestamp type
145164
ZonedDateTime timestamp = record.getTimestamp(fieldName);
146165
String timestampString = Timestamp.valueOf(timestamp.toLocalDateTime()).toString();
147166
Object timestampWithTimeZone = createOracleTimestampWithLocalTimeZone(stmt.getConnection(), timestampString);
148167
stmt.setObject(sqlIndex, timestampWithTimeZone);
149-
} else if (Schema.LogicalType.DATETIME.equals(fieldSchema.getLogicalType())) {
168+
} else if (LogicalType.DATETIME.equals(fieldSchema.getLogicalType())) {
150169
// Handle the case when the TimestampLTZ is mapped to CDAP Datetime type
151170
LocalDateTime localDateTime = record.getDateTime(fieldName);
152171
String timestampString = Timestamp.valueOf(localDateTime).toString();
153172
Object timestampWithTimeZone = createOracleTimestampWithLocalTimeZone(stmt.getConnection(), timestampString);
154173
stmt.setObject(sqlIndex, timestampWithTimeZone);
155174
}
156175
} else if (sqlType == Types.TIMESTAMP) {
157-
if (Schema.LogicalType.DATETIME.equals(fieldSchema.getLogicalType())) {
176+
if (LogicalType.DATETIME.equals(fieldSchema.getLogicalType())) {
158177
// Handle the case when Timestamp is mapped to CDAP Datetime type.
159178
LocalDateTime localDateTime = record.getDateTime(fieldName);
160179
String timestampString = Timestamp.valueOf(localDateTime).toString();
@@ -257,7 +276,7 @@ private byte[] getBfileBytes(ResultSet resultSet, String columnName) throws SQLE
257276
}
258277
}
259278

260-
private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field,
279+
private void handleOracleSpecificType(ResultSet resultSet, Builder recordBuilder, Field field,
261280
int columnIndex, int sqlType, int precision, int scale)
262281
throws SQLException {
263282
Schema nonNullSchema = field.getSchema().isNullable() ?
@@ -270,7 +289,7 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil
270289
recordBuilder.set(field.getName(), resultSet.getString(columnIndex));
271290
break;
272291
case OracleSourceSchemaReader.TIMESTAMP_TZ:
273-
if (Schema.Type.STRING.equals(nonNullSchema.getType())) {
292+
if (Type.STRING.equals(nonNullSchema.getType())) {
274293
recordBuilder.set(field.getName(), resultSet.getString(columnIndex));
275294
} else {
276295
// In case of TimestampTZ datatype the getTimestamp(index, Calendar) method call does not
@@ -298,7 +317,7 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil
298317
case Types.TIMESTAMP:
299318
// Since Oracle Timestamp type does not have any timezone information, it should be converted into the
300319
// CDAP Datetime type.
301-
if (Schema.LogicalType.DATETIME.equals(nonNullSchema.getLogicalType())) {
320+
if (LogicalType.DATETIME.equals(nonNullSchema.getLogicalType())) {
302321
Timestamp timestamp = resultSet.getTimestamp(columnIndex);
303322
if (timestamp != null) {
304323
recordBuilder.setDateTime(field.getName(), timestamp.toLocalDateTime());
@@ -315,7 +334,7 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil
315334
// super.setField sets this '0000-12-31 09:00:00.000Z[UTC]' in the recordBuilder which is incorrect and the
316335
// correct value should be '0001-01-01 09:00:00.000Z[UTC]'.
317336
Object timeStampObj = resultSet.getObject(columnIndex);
318-
if (Schema.LogicalType.DATETIME.equals(nonNullSchema.getLogicalType())) {
337+
if (LogicalType.DATETIME.equals(nonNullSchema.getLogicalType())) {
319338
Timestamp timestampLTZ = resultSet.getTimestamp(columnIndex);
320339
if (timestampLTZ != null) {
321340
recordBuilder.setDateTime(field.getName(),
@@ -351,7 +370,7 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil
351370
if (precision == 0) {
352371
Schema nonNullableSchema = field.getSchema().isNullable() ?
353372
field.getSchema().getNonNullable() : field.getSchema();
354-
if (Schema.LogicalType.DECIMAL.equals(nonNullableSchema.getLogicalType())) {
373+
if (LogicalType.DECIMAL.equals(nonNullableSchema.getLogicalType())) {
355374
// Handle the field using the schema set in the output schema
356375
BigDecimal decimal = resultSet.getBigDecimal(columnIndex, getScale(field.getSchema()));
357376
recordBuilder.setDecimal(field.getName(), decimal);
@@ -382,8 +401,8 @@ private boolean isLongOrLongRaw(int columnType) {
382401
return columnType == OracleSourceSchemaReader.LONG || columnType == OracleSourceSchemaReader.LONG_RAW;
383402
}
384403

385-
private void readField(int columnIndex, ResultSetMetaData metadata, ResultSet resultSet, Schema.Field field,
386-
StructuredRecord.Builder recordBuilder) throws SQLException {
404+
private void readField(int columnIndex, ResultSetMetaData metadata, ResultSet resultSet, Field field,
405+
Builder recordBuilder) throws SQLException {
387406
int sqlType = metadata.getColumnType(columnIndex);
388407
int sqlPrecision = metadata.getPrecision(columnIndex);
389408
int sqlScale = metadata.getScale(columnIndex);

0 commit comments

Comments
 (0)